From 285c6755d165574d53b1f7e875300ac877db6473 Mon Sep 17 00:00:00 2001 From: AbdelRauf Date: Sun, 23 May 2021 21:51:32 +0200 Subject: [PATCH] Ctc beam search decoder Signed-off-by: brian --- .../blas/ops/declarable/CustomOperations.h | 1 + .../generic/decoder/ctc_beam_op.cpp | 111 + .../ops/declarable/generic/loss/ctcLoss.cpp | 26 +- .../cpp/blas/ops/declarable/headers/decoder.h | 65 + .../cpp/blas/ops/declarable/headers/loss.h | 3 +- .../ops/declarable/helpers/cpu/ctcLoss.cpp | 65 +- .../cpp/blas/ops/declarable/helpers/ctc.h | 153 + .../cpp/blas/ops/declarable/helpers/ctcLoss.h | 55 - .../ops/declarable/helpers/cuda/ctcLoss.cu | 2 +- .../ops/declarable/helpers/impl/ctcBeam.cpp | 718 ++ .../src/test/tests_cpu/CMakeLists.txt | 31 + .../src/test/tests_cpu/CMakeLists.txt.in | 16 + .../test/tests_cpu/layers_tests/AllTests.cpp | 47 + .../layers_tests/ArrayOptionsTests.cpp | 111 + .../tests_cpu/layers_tests/AtomicTests.cu | 243 + .../tests_cpu/layers_tests/AttentionTests.cpp | 222 + .../tests_cpu/layers_tests/BackpropTests.cpp | 51 + .../layers_tests/BitwiseUtilsTests.cpp | 78 + .../layers_tests/BooleanOpsTests.cpp | 150 + .../layers_tests/BroadcastableOpsTests.cpp | 857 ++ .../tests_cpu/layers_tests/BrodcastTests.cpp | 68 + .../tests_cpu/layers_tests/CMakeLists.txt | 171 + .../test/tests_cpu/layers_tests/CnpyTests.cpp | 96 + .../layers_tests/ConditionalTests.cpp | 334 + .../layers_tests/ConstantShapeHelperTests.cpp | 351 + .../tests_cpu/layers_tests/ContextTests.cpp | 358 + .../layers_tests/ConvolutionTests1.cpp | 2921 +++++++ .../layers_tests/ConvolutionTests2.cpp | 2859 +++++++ .../test/tests_cpu/layers_tests/CuDnnTests.cu | 150 + .../layers_tests/CudaBasicsTests1.cu | 2926 +++++++ .../layers_tests/CudaBasicsTests2.cu | 1159 +++ .../layers_tests/CudaExtraArgumentsTests.cu | 76 + .../layers_tests/CudaLaunchHelperTests.cpp | 48 + .../layers_tests/DataBufferTests.cpp | 80 + .../layers_tests/DataBufferTestsCuda.cu | 91 + .../layers_tests/DataTypesValidationTests.cpp | 158 
+ .../layers_tests/DeclarableOpsTests1.cpp | 3370 ++++++++ .../layers_tests/DeclarableOpsTests10.cpp | 3238 ++++++++ .../layers_tests/DeclarableOpsTests11.cpp | 4030 ++++++++++ .../layers_tests/DeclarableOpsTests12.cpp | 3469 ++++++++ .../layers_tests/DeclarableOpsTests13.cpp | 2862 +++++++ .../layers_tests/DeclarableOpsTests14.cpp | 2454 ++++++ .../layers_tests/DeclarableOpsTests15.cpp | 2027 +++++ .../layers_tests/DeclarableOpsTests16.cpp | 1499 ++++ .../layers_tests/DeclarableOpsTests17.cpp | 94 + .../layers_tests/DeclarableOpsTests18.cpp | 1683 ++++ .../layers_tests/DeclarableOpsTests19.cpp | 427 + .../layers_tests/DeclarableOpsTests2.cpp | 4487 +++++++++++ .../layers_tests/DeclarableOpsTests3.cpp | 2764 +++++++ .../layers_tests/DeclarableOpsTests4.cpp | 2450 ++++++ .../layers_tests/DeclarableOpsTests5.cpp | 3086 ++++++++ .../layers_tests/DeclarableOpsTests6.cpp | 2807 +++++++ .../layers_tests/DeclarableOpsTests7.cpp | 7000 +++++++++++++++++ .../layers_tests/DeclarableOpsTests8.cpp | 3525 +++++++++ .../layers_tests/DeclarableOpsTests9.cpp | 2592 ++++++ .../layers_tests/DeclarableOpsTestsCuda1.cu | 78 + .../tests_cpu/layers_tests/EmptyTests.cpp | 256 + .../layers_tests/ExtraArgumentsTests.cpp | 68 + .../layers_tests/FlatBuffersTests.cpp | 817 ++ .../tests_cpu/layers_tests/FlatUtilsTests.cpp | 104 + .../layers_tests/GraphExecutionerTests.cpp | 105 + .../layers_tests/GraphHolderTests.cpp | 88 + .../GraphRandomGeneratorTests.cpp | 266 + .../layers_tests/GraphStateTests.cpp | 351 + .../tests_cpu/layers_tests/GraphTests.cpp | 1640 ++++ .../tests_cpu/layers_tests/HashUtilsTests.cpp | 45 + .../tests_cpu/layers_tests/HelpersTests1.cpp | 2347 ++++++ .../tests_cpu/layers_tests/HelpersTests2.cpp | 429 + .../tests_cpu/layers_tests/IndexingTests.cpp | 472 ++ .../layers_tests/JavaInteropCudaTests.cu | 89 + .../layers_tests/JavaInteropTests.cpp | 1500 ++++ .../tests_cpu/layers_tests/LambdaTests.cu | 221 + .../layers_tests/LaunchContextCudaTests.cu | 127 + 
.../layers_tests/LegacyOpsCudaTests.cu | 114 + .../tests_cpu/layers_tests/LegacyOpsTests.cpp | 770 ++ .../layers_tests/ListOperationsTests.cpp | 663 ++ .../layers_tests/LoopCoordsHelperTests.cpp | 225 + .../layers_tests/MemoryUtilsTests.cpp | 48 + .../tests_cpu/layers_tests/MklDnnTests.cpp | 111 + .../test/tests_cpu/layers_tests/MmapTests.cpp | 57 + .../layers_tests/MultiDataTypeTests.cpp | 1984 +++++ .../layers_tests/MultiDeviceTests.cpp | 72 + .../layers_tests/NDArrayConstructorsTests.cu | 208 + .../layers_tests/NDArrayCudaBasicsTests.cu | 2200 ++++++ .../layers_tests/NDArrayListTests.cpp | 75 + .../tests_cpu/layers_tests/NDArrayTests.cpp | 2682 +++++++ .../tests_cpu/layers_tests/NDArrayTests2.cpp | 1309 +++ .../tests_cpu/layers_tests/NativeOpsTests.cpp | 1612 ++++ .../test/tests_cpu/layers_tests/NlpTests.cpp | 476 ++ .../test/tests_cpu/layers_tests/NodeTests.cpp | 75 + .../layers_tests/OmpLaunchHelperTests.cpp | 125 + .../tests_cpu/layers_tests/OneOffTests.cpp | 390 + .../tests_cpu/layers_tests/OpTrackerTests.cpp | 71 + .../tests_cpu/layers_tests/OpTupleTests.cpp | 61 + .../tests_cpu/layers_tests/PairwiseTests.cpp | 52 + .../tests_cpu/layers_tests/ParityOpsTests.cpp | 1700 ++++ .../layers_tests/PerformanceTests.cpp | 146 + .../layers_tests/PlaygroundTests.cpp | 1686 ++++ .../layers_tests/PrimitivesTests.cpp | 94 + .../tests_cpu/layers_tests/ProtoBufTests.cpp | 112 + .../layers_tests/QuantizationTests.cpp | 72 + .../test/tests_cpu/layers_tests/RNGTests.cpp | 1299 +++ .../tests_cpu/layers_tests/ResultSetTests.cpp | 51 + .../tests_cpu/layers_tests/SanityTests.cpp | 64 + .../tests_cpu/layers_tests/ScalarTests.cpp | 238 + .../tests_cpu/layers_tests/ScopeTests.cpp | 167 + .../layers_tests/ServerRelatedTests.cpp | 190 + .../tests_cpu/layers_tests/ShapeTests.cpp | 336 + .../tests_cpu/layers_tests/ShapeTests2.cpp | 820 ++ .../layers_tests/ShapeUtilsTests.cpp | 295 + .../tests_cpu/layers_tests/SingleDimTests.cpp | 187 + .../tests_cpu/layers_tests/SortCpuTests.cpp | 106 + 
.../tests_cpu/layers_tests/SortCudaTests.cu | 126 + .../layers_tests/SparseUtilsTest.cpp | 248 + .../tests_cpu/layers_tests/StashTests.cpp | 90 + .../tests_cpu/layers_tests/StringTests.cpp | 880 +++ .../tests_cpu/layers_tests/SwitchTests.cpp | 253 + .../test/tests_cpu/layers_tests/TadTests.cpp | 445 ++ .../tests_cpu/layers_tests/ThreadsTests.cpp | 271 + .../tests_cpu/layers_tests/TypeCastTests.cpp | 74 + .../layers_tests/VariableProxyTests.cpp | 175 + .../layers_tests/VariableSpaceTests.cpp | 222 + .../tests_cpu/layers_tests/VariableTests.cpp | 227 + .../tests_cpu/layers_tests/WorkspaceTests.cpp | 291 + .../tests_cpu/layers_tests/WorkspaceTests.cu | 62 + .../tests_cpu/layers_tests/suppressions.txt | 2 + .../test/tests_cpu/layers_tests/testinclude.h | 54 + .../test/tests_cpu/layers_tests/testlayers.h | 43 + .../tests_cpu/libnd4j_tests/CMakeLists.txt | 301 + .../tests_cpu/resources/arr_3,4_float32.npy | Bin 0 -> 176 bytes .../resources/assert_type_rank2_int64.fb | Bin 0 -> 1664 bytes .../tests_cpu/resources/assertsomething.fb | Bin 0 -> 11736 bytes .../test/tests_cpu/resources/avg_pooling3d.fb | Bin 0 -> 4808 bytes .../channels_last_b1_k2_s1_d1_SAME_crelu.fb | Bin 0 -> 4320 bytes .../src/test/tests_cpu/resources/cond_true.fb | Bin 0 -> 4088 bytes .../test/tests_cpu/resources/identity_n_2.fb | Bin 0 -> 1120 bytes .../src/test/tests_cpu/resources/non2d_0A.fb | Bin 0 -> 1960 bytes .../src/test/tests_cpu/resources/non2d_1.fb | Bin 0 -> 4600 bytes .../src/test/tests_cpu/resources/pad_1D.fb | Bin 0 -> 1336 bytes .../resources/reduce_all_rank2_d0_keep.fb | Bin 0 -> 1240 bytes .../tests_cpu/resources/scalar_float32.fb | Bin 0 -> 11832 bytes .../tests_cpu/resources/scatter_nd_update.fb | Bin 0 -> 2584 bytes .../tests_cpu/resources/simpleif_0_alt.fb | Bin 0 -> 7648 bytes .../test/tests_cpu/resources/simplewhile_1.fb | Bin 0 -> 12504 bytes .../tests_cpu/resources/simplewhile_nested.fb | Bin 0 -> 26808 bytes ...se_sz1_float32_nodynamic_noname_noshape.fb | Bin 0 -> 2920 bytes 
...it_sz1_float32_nodynamic_noname_noshape.fb | Bin 0 -> 3384 bytes ...ay_stack_sz3-1_int32_dynamic_name_shape.fb | Bin 0 -> 6280 bytes ...ack_sz1_int64_nodynamic_noname_shape2-3.fb | Bin 0 -> 6400 bytes .../test/tests_cpu/resources/while_iter3.fb | Bin 0 -> 9512 bytes .../src/test/tests_cpu/run_minifier.sh | 153 + .../src/test/tests_cpu/run_tests.sh | 68 + 152 files changed, 102529 insertions(+), 117 deletions(-) create mode 100644 cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/generic/decoder/ctc_beam_op.cpp create mode 100644 cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/headers/decoder.h create mode 100644 cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/ctc.h delete mode 100644 cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/ctcLoss.h create mode 100644 cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/impl/ctcBeam.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/CMakeLists.txt create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/CMakeLists.txt.in create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AllTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ArrayOptionsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AtomicTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AttentionTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BackpropTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BitwiseUtilsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BooleanOpsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BroadcastableOpsTests.cpp create mode 100644 
cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BrodcastTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CMakeLists.txt create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CnpyTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConditionalTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ContextTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConvolutionTests1.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConvolutionTests2.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CuDnnTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaBasicsTests1.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaBasicsTests2.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataBufferTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataBufferTestsCuda.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataTypesValidationTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests1.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests10.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests11.cpp create mode 100644 
cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests12.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests13.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests14.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests15.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests16.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests17.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests18.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests19.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests2.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests3.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests4.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests5.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests6.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests7.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests8.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests9.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/EmptyTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ExtraArgumentsTests.cpp create mode 100644 
cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/FlatBuffersTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/FlatUtilsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphExecutionerTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphHolderTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphStateTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HashUtilsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HelpersTests1.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HelpersTests2.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/IndexingTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/JavaInteropCudaTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/JavaInteropTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LambdaTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LaunchContextCudaTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LegacyOpsCudaTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LegacyOpsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ListOperationsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp create mode 100644 
cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MemoryUtilsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MklDnnTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MmapTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MultiDataTypeTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MultiDeviceTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayConstructorsTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayListTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayTests2.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NativeOpsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NlpTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NodeTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OneOffTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OpTrackerTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OpTupleTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PairwiseTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ParityOpsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PerformanceTests.cpp create mode 
100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PlaygroundTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PrimitivesTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ProtoBufTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/QuantizationTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/RNGTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ResultSetTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SanityTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ScalarTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ScopeTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ServerRelatedTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeTests2.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeUtilsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SingleDimTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SortCpuTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SortCudaTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SparseUtilsTest.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/StashTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/StringTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SwitchTests.cpp create mode 100644 
cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/TadTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ThreadsTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/TypeCastTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableProxyTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableSpaceTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/WorkspaceTests.cpp create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/WorkspaceTests.cu create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/suppressions.txt create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/testinclude.h create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/testlayers.h create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/libnd4j_tests/CMakeLists.txt create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/arr_3,4_float32.npy create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/assert_type_rank2_int64.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/assertsomething.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/avg_pooling3d.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/cond_true.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/identity_n_2.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/non2d_0A.fb create mode 100644 
cavis-native/cavis-native-lib/src/test/tests_cpu/resources/non2d_1.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/pad_1D.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/reduce_all_rank2_d0_keep.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/scalar_float32.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/scatter_nd_update.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/simpleif_0_alt.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/simplewhile_1.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/simplewhile_nested.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/resources/while_iter3.fb create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/run_minifier.sh create mode 100644 cavis-native/cavis-native-lib/src/test/tests_cpu/run_tests.sh diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/CustomOperations.h b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/CustomOperations.h index 339914213..e583b5661 100644 --- a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/CustomOperations.h +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/CustomOperations.h @@ -49,6 +49,7 @@ #include #include #include +#include 
#include #include #include diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/generic/decoder/ctc_beam_op.cpp b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/generic/decoder/ctc_beam_op.cpp new file mode 100644 index 000000000..15ff3951d --- /dev/null +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/generic/decoder/ctc_beam_op.cpp @@ -0,0 +1,111 @@ +/******************************************************************************* + * Copyright (c) 2021 Deeplearning4j Contributors + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + *******************************************************************************/ + +// +// @author AbdelRauf +// + +#include +#include +#include + +namespace sd { +namespace ops { + + +////////////////////////////////////////////////////////////////////////// +CUSTOM_OP_IMPL(ctc_beam, 2, 3, false, 0, -2) { + + auto logit = INPUT_VARIABLE(0); + auto sequence_length = INPUT_VARIABLE(1); + auto result_sequences = OUTPUT_VARIABLE(0); + auto result_probs = OUTPUT_VARIABLE(1); + auto result_sequences_length = OUTPUT_VARIABLE(2); + auto arg_size = block.getIArguments()->size(); + auto normalize_logits = block.numB() > 0 ? B_ARG(0) : false; + + int blank_index = arg_size>0 ? INT_ARG(0) : -1; + int beam_width = arg_size>1 ? INT_ARG(1) : 25; + int nbest_len = arg_size>2? 
INT_ARG(2): 1; + + REQUIRE_TRUE(logit->rankOf()==3, 0, "Ctc Beam Search: logit Input fails to meet rank requirement {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }: %i == 3 ", logit->rankOf()); + REQUIRE_TRUE(sequence_length->rankOf()==1, 0, "Ctc Beam Search: sequence frame length (sequence_length) Input fails to meet rank requirement {BATCH_LEN}: %i == 1 ", sequence_length->rankOf()); + + REQUIRE_TRUE(result_sequences->rankOf()==3, 0, "Ctc Beam Search: result_sequences Output fails to meet rank requirement {BATCH_LEN, NBEST_LEN, MAX_FRAME_LEN }: %i == 3 ", result_sequences->rankOf()); + REQUIRE_TRUE(result_probs->rankOf()==2, 0, "Ctc Beam Search: result_probs Output fails to meet rank requirement {BATCH_LEN, NBEST_LEN}: %i == 2 ", result_probs->rankOf()); + REQUIRE_TRUE(result_sequences_length->rankOf()==2, 0, "Ctc Beam Search: result_sequences_length Output fails to meet rank requirement {BATCH_LEN, NBEST_LEN}: %i == 2 ", result_sequences_length->rankOf()); + + auto batchSize0 = logit->sizeAt(0); + auto batchSize1 = sequence_length->sizeAt(0); + auto batchSize2 = result_sequences->sizeAt(0); + auto batchSize3 = result_probs->sizeAt(0); + auto batchSize4 = result_sequences_length->sizeAt(0); + + bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3); + check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2); + + REQUIRE_TRUE(nbest_len>0 && nbest_len <=beam_width, 0, "Ctc Beam Search: nbest_len %i should be > 0 and <= %i", nbest_len, beam_width); + REQUIRE_TRUE(check_batches, 0, "Ctc Beam Search: All batch sizes should be %i", batchSize0); + auto max_t = logit->sizeAt(1); + REQUIRE_TRUE(result_sequences->sizeAt(1) == nbest_len && result_sequences->sizeAt(2) == max_t , 0, "Ctc Beam Search: shape of the result_sequences should be {%i, %i, %i} but got { %i, %i, %i}", + batchSize0, nbest_len, max_t, batchSize1, result_sequences->sizeAt(1), result_sequences->sizeAt(2)); + REQUIRE_TRUE(result_probs->sizeAt(1) == nbest_len , 
0, "Ctc Beam Search: shape of the result_probs should be {%i, %i} but got { %i, %i}", + batchSize0, nbest_len, batchSize3, result_sequences->sizeAt(1)); + REQUIRE_TRUE(result_sequences_length->sizeAt(1) == nbest_len , 0, "Ctc Beam Search: shape of the result_sequences_length should be {%i, %i} but got { %i, %i}", + batchSize0, nbest_len, batchSize4, result_sequences_length->sizeAt(1)); + REQUIRE_TRUE(result_sequences->ews()==1 && result_sequences->ordering()=='c', 0, "Ctc Beam Search: result_sequences output should be ews()==1 and c order: %d == ews(1) %c == order(c) ", result_sequences->ews(), result_sequences->ordering()); + REQUIRE_TRUE(result_probs->ews()==1 && result_probs->ordering()=='c', 0, "Ctc Beam Search: result_probs output should be ews()==1 and c order: %d == ews(1) %c == order(c) ", result_probs->ews(), result_probs->ordering()); + REQUIRE_TRUE(result_sequences_length->ews()==1 && result_sequences_length->ordering()=='c', 0, "Ctc Beam Search: result_sequences_length output should be ews()==1 and c order: %d == ews(1) %c == order(c) ", result_sequences_length->ews(), result_sequences_length->ordering()); + + sd::ops::helpers::beamSearch(*logit, *sequence_length, *result_sequences, *result_probs, *result_sequences_length, blank_index, beam_width, nbest_len, normalize_logits); + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_TYPES(ctc_beam) { + getOpDescriptor()->setAllowedInputTypes({ALL_FLOATS}) + ->setAllowedInputTypes(1,{ALL_INDICES}) + ->setAllowedOutputTypes(0, {ALL_INDICES}) + ->setAllowedOutputTypes(1, {ALL_FLOATS}) + ->setAllowedOutputTypes(2, {ALL_INDICES}); +} + +////////////////////////////////////////////////////////////////////////// +DECLARE_SHAPE_FN(ctc_beam) { + auto logitShapeInfo = inputShape->at(0); + auto sequenceShapeInfo = inputShape->at(1); + auto arg_size = block.getIArguments()->size(); + + auto nbest_len = arg_size>2? 
INT_ARG(2): 1; + + REQUIRE_TRUE(logitShapeInfo[0] ==3 , 0, "Ctc Beam Search: logit Input fails to meet rank requirement {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }: %i == 3", + logitShapeInfo[0]); + + auto batch_size = shape::shapeOf(logitShapeInfo)[0] ; + auto max_t = shape::shapeOf(logitShapeInfo)[1] ; + + auto dtype_float = ArrayOptions::dataType(logitShapeInfo); + auto dtype_index = ArrayOptions::dataType(sequenceShapeInfo); + + auto output0 = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(dtype_index, 'c', {batch_size, nbest_len, max_t})); + auto output1 = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(dtype_float, 'c', {batch_size, nbest_len})); + auto output2 = ConstantShapeHelper::getInstance().createShapeInfo(ShapeDescriptor(dtype_index, 'c', {batch_size, nbest_len})); + return SHAPELIST(output0, output1, output2); +} + +} +} diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/generic/loss/ctcLoss.cpp b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/generic/loss/ctcLoss.cpp index d37c16233..6c739b1c4 100644 --- a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/generic/loss/ctcLoss.cpp +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/generic/loss/ctcLoss.cpp @@ -21,7 +21,7 @@ #include #include -#include +#include namespace sd { namespace ops { @@ -43,16 +43,16 @@ CUSTOM_OP_IMPL(ctc_loss, 4, 1, false, 0, 1) { REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf()); REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf()); - int batchSize0 = targetLabels->sizeAt(0); - int batchSize1 = logitInput->sizeAt(0); - int batchSize2 = targetLabelLengths->sizeAt(0); - int batchSize3 = logitInputLengths->sizeAt(0); - int batchSize4 = 
outputLosses->sizeAt(0); + auto batchSize0 = targetLabels->sizeAt(0); + auto batchSize1 = logitInput->sizeAt(0); + auto batchSize2 = targetLabelLengths->sizeAt(0); + auto batchSize3 = logitInputLengths->sizeAt(0); + auto batchSize4 = outputLosses->sizeAt(0); bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3); check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2); - REQUIRE_TRUE(check_batches, 0, "CtcLoss: All batch sizes should be equal %i", batchSize0); + REQUIRE_TRUE(check_batches, 0, "CtcLoss: All batch sizes should be %i", batchSize0); REQUIRE_TRUE(outputLosses->isSameShape(targetLabelLengths), 0, "CtcLoss: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(targetLabelLengths).c_str(), ShapeUtils::shapeAsString(outputLosses).c_str()); auto emptyGradients = NDArrayFactory::empty(); @@ -95,16 +95,16 @@ CUSTOM_OP_IMPL(ctc_loss_grad, 4, 1, false, 0, 1) { REQUIRE_TRUE(targetLabelLengths->rankOf()==1, 0, "CtcLoss: target label length fails to meet rank requirement (batch_size): %i == 1 ", targetLabelLengths->rankOf()); REQUIRE_TRUE(logitInputLengths->rankOf()==1, 0, "CtcLoss: logit Input lengths fails to meet rank requirement (batch_size): %i == 1 ", logitInputLengths->rankOf()); - int batchSize0 = targetLabels->sizeAt(0); - int batchSize1 = logitInput->sizeAt(0); - int batchSize2 = targetLabelLengths->sizeAt(0); - int batchSize3 = logitInputLengths->sizeAt(0); - int batchSize4 = outputGradients->sizeAt(0); + auto batchSize0 = targetLabels->sizeAt(0); + auto batchSize1 = logitInput->sizeAt(0); + auto batchSize2 = targetLabelLengths->sizeAt(0); + auto batchSize3 = logitInputLengths->sizeAt(0); + auto batchSize4 = outputGradients->sizeAt(0); bool check_batches = (batchSize0 == batchSize1) && (batchSize2 == batchSize3); check_batches = check_batches && (batchSize0 == batchSize4) && (batchSize0 == batchSize2); - REQUIRE_TRUE(check_batches, 0, "CtcLoss Gradient: All 
batch sizes should be equal %i", batchSize0); + REQUIRE_TRUE(check_batches, 0, "CtcLoss Gradient: All batch sizes should be %i", batchSize0); REQUIRE_TRUE(outputGradients->isSameShape(logitInput), 0, "CtcLoss Gradient: wrong shape of output array, expected is %s but got %s instead !", ShapeUtils::shapeAsString(logitInput).c_str(), ShapeUtils::shapeAsString(outputGradients).c_str()); auto emptyLoss = NDArrayFactory::empty(); diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/headers/decoder.h b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/headers/decoder.h new file mode 100644 index 000000000..f72b257e3 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/headers/decoder.h @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author AbdelRauf +// + +#ifndef LIBND4J_HEADERS_DECODER_H +#define LIBND4J_HEADERS_DECODER_H + +#include + +namespace sd { +namespace ops { + + /** + * Implementation of CTC beam search + * + * Input arrays: + * 0: logits - logits NDArray logit NDArray {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }. It should include a blank label as well. 
type float + * 1: sequence_length - NDArray {BATCH_LEN} length of frames. type integer + * + * Input integer arguments (IArgs): + * 0: blank_index the index of the blank label in logits. default is last class. CLASS_LEN-1 + * 1: beam_width the width of the beam search. default is 25 + * 2: nbest_len the number of top best results that should be returned. default is 1 + * NOTE: if it is > beam_width it will be defaulted to beam_width size. + * Input bool argument (BArgs): + * 0: normalize_logit when its true it will normalize logits. by default it is assumed logit contains already normalized log-probabilities + * Output array: + * 0: result_sequences NDArray {BATCH_LEN, NBEST, MAX_FRAME_LEN} result sequences. + * NOTE: result_sequences NdArray should be c order and have ews == 1. type integer + * 1: result_probs NDArray {BATCH_LEN, NBEST} negative log probabilities for each sequence. type float + * NOTE: result_probs NdArray should be c order and have ews == 1 + * 2: result_sequence_length NDArray {BATCH_LEN, NBEST} the length of the each sequence. type integer + * NOTE: result_sequence_length NdArray should be c order and have ews == 1 + * + * NOTE: + * maximum value of integer indexing type should be >= CLASS_LEN to make sense. And also it should consider frame lengthes as well. + * For now this case is mostly fine as only Indexing types are allowed as integer. 
+ */ + #if NOT_EXCLUDED(OP_ctc_beam) + DECLARE_CUSTOM_OP(ctc_beam, 2, 3, false, 0, -2); + #endif + + +} +} + +#endif diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/headers/loss.h b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/headers/loss.h index 303470792..aecdc9cb0 100644 --- a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/headers/loss.h +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/headers/loss.h @@ -365,7 +365,8 @@ namespace ops { * * Input arrays: * 0: labels - labels NDArray {BATCH_LEN, MAX_TARGET_LEN}, type integer - * 1: logits - logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. log softmax of rnn output. It should include a blank label as well, type float + * 1: logits - logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. It should include a blank label as well, type float + * NOTE: we expect normalized logits (softmax normalized logarithm values for logits). * 2: targetLabelLengths - Length of label sequence in labels NDArray {BATCH_LEN}, type integer * 3: logitsLengths - Length of input sequence in logits NDArray {BATCH_LEN}, type integer * diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/cpu/ctcLoss.cpp b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/cpu/ctcLoss.cpp index 1920167f2..7fbd7142b 100644 --- a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/cpu/ctcLoss.cpp +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/cpu/ctcLoss.cpp @@ -25,7 +25,7 @@ #include #include #include -#include +#include namespace sd { @@ -34,26 +34,11 @@ namespace sd namespace helpers { - //choose ptr[index*element_stride] - template - typename std::enable_if::type - element(Type *ptr, int index, int element_stride) - { - return ptr[index * element_stride]; - } - - //choose ptr[index] assuming element_stride is 1 - template - typename std::enable_if::type - element(Type *ptr, int 
index, int element_stride) - { - return ptr[index]; - } template Type forward(Type *alphaPtr, const Nd4jLong &incA, const Type *logP, const Nd4jLong &incP, const IndexType *lbl, const Nd4jLong &lenSB, const Nd4jLong &lenT, const int &blankIndex, int elwiseP = 1, int elwiseS = 1) { - Type negInf = -DataTypeUtils::infOrMax(); + Type negInf = negative_infinity(); //initialize alphas at t=0 alphaPtr[0] = element(logP, blankIndex, elwiseP); //alphaPtr[1] =logP[lbl[0]]; @@ -82,23 +67,17 @@ namespace sd // {t-1,s} Type alphaS = alphaPrevPtr[s]; Type alphaS_1 = s > 0 ? alphaPrevPtr[s - 1] : negInf; - Type cMax = std::max(alphaS, alphaS_1); //logP[currentInd] or logP[currentInd*elwiseP] auto currentProb = element(logP, currentInd, elwiseP); // if blank or the same as previous if (s > 1 && currentInd != blankIndex && currentInd != element(lbl, ind - 1, elwiseS)) { Type alphaS_2 = alphaPrevPtr[s - 2]; - cMax = std::max(cMax, alphaS_2); - if (cMax == negInf) - cMax = 0; - alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax) + std::exp(alphaS_2 - cMax)) + cMax + currentProb; + alphaPtr[s] = log_sum_exp(alphaS, alphaS_1, alphaS_2) + currentProb; } else { - if (cMax == negInf) - cMax = 0; - alphaPtr[s] = std::log(std::exp(alphaS - cMax) + std::exp(alphaS_1 - cMax)) + cMax + currentProb; + alphaPtr[s] = log_sum_exp(alphaS, alphaS_1) + currentProb; } } @@ -109,8 +88,7 @@ namespace sd } auto logP0 = alphaPrevPtr[lenSB - 1]; auto logP1 = alphaPrevPtr[lenSB - 2]; - auto cMax = std::max(logP0, logP1); - return -(std::log(std::exp(logP0 - cMax) + std::exp(logP1 - cMax)) + cMax); + return -log_sum_exp(logP0, logP1 ); } //#undef CALCULATE_ALL_IN_ONE_FRAME_LOOP @@ -121,7 +99,7 @@ namespace sd int elwiseP = 1, int elwiseS = 1, int elwiseG = 1) { - Type negInf = -DataTypeUtils::infOrMax(); + Type negInf = negative_infinity(); Nd4jLong lenSB = 2 * lenS + 1; auto origBetta = bettaPtr; auto origLogP = logP; @@ -197,23 +175,17 @@ namespace sd // {t-1,s} Type bettaS = 
bettaPrevPtr[s]; Type bettaS_1 = s < lenSB - 1 ? bettaPrevPtr[s + 1] : negInf; - Type cMax = std::max(bettaS, bettaS_1); //logP[currentInd] auto currentProb = element(logP, currentInd, elwiseP); // if blank or the same as previous if (s < lenSB - 2 && currentInd != blankIndex && currentInd != element(lbl, ind + 1, elwiseS)) { Type bettaS_2 = bettaPrevPtr[s + 2]; - cMax = std::max(cMax, bettaS_2); - if (cMax == negInf) - cMax = 0; - bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax) + std::exp(bettaS_2 - cMax)) + cMax + currentProb; + bettaPtr[s] = log_sum_exp(bettaS, bettaS_1, bettaS_2) + currentProb; } else { - if (cMax == negInf) - cMax = 0; - bettaPtr[s] = std::log(std::exp(bettaS - cMax) + std::exp(bettaS_1 - cMax)) + cMax + currentProb; + bettaPtr[s] = log_sum_exp(bettaS, bettaS_1) + currentProb; } #if defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP) @@ -262,8 +234,7 @@ namespace sd auto logBP0 = bettaPrevPtr[0]; auto logBP1 = bettaPrevPtr[1]; - auto bcMax = std::max(logBP0, logBP1); - auto blogLoss = -(std::log(std::exp(logBP0 - bcMax) + std::exp(logBP1 - bcMax)) + bcMax); + auto blogLoss = -log_sum_exp(logBP0, logBP1); #if !defined(CALCULATE_ALL_IN_ONE_FRAME_LOOP) //alpha*betta @@ -289,8 +260,7 @@ namespace sd } else { - Type cMax = std::max(currentGrad, alphaBettaS); - currentGrad = std::log(std::exp(currentGrad - cMax) + std::exp(alphaBettaS - cMax)) + cMax; + currentGrad = log_sum_exp(currentGrad, alphaBettaS); } //alphaPtr[s] = alphaBettaS; } @@ -345,7 +315,7 @@ namespace sd auto bufferPtr = bufferArr.bufferAsT(); auto incA = bufferArr.stridesOf()[1]; auto bettaBufferPtr = bufferPtr + bufferArr.stridesOf()[0]; - Type negInf = -DataTypeUtils::infOrMax(); + Type negInf = negative_infinity(); #if 1 if (gradPtr) @@ -421,7 +391,8 @@ namespace sd elwiseLL = logLosses.stridesOf()[0]; logLossPtr = logLosses.bufferAsT(); } - + //defaulting blankIndex to the last class if its incorrect or -1 + if (blankIndex > maxLenS || blankIndex < 0) blankIndex 
= maxLenS - 1; auto func = [logP, batchP, incP, elwiseP, lenK, lenTPtr, lenSPtr, logLossPtr, lblPtr, maxLenT, maxLenS, batchLbl, blankIndex, elwiseT, elwiseLL, elwiseSLen, elwiseS, &gradients](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { Type *gradPtr = nullptr; @@ -450,7 +421,7 @@ namespace sd lenS = lenS > maxLenS ? maxLenS : lenS; if (lenS <= 0 || lenT <= 0) { - resultLoss = -DataTypeUtils::infOrMax(); + resultLoss = negative_infinity(); } else { @@ -475,7 +446,7 @@ namespace sd lenS = lenS > maxLenS ? maxLenS : lenS; if (lenS <= 0 || lenT <= 0) { - resultLoss = -DataTypeUtils::infOrMax(); + resultLoss = negative_infinity(); } else { @@ -495,11 +466,11 @@ namespace sd void ctcLoss(graph::Context& block, const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex){ - BUILD_DOUBLE_SELECTOR(logits.dataType(), targetLabels.dataType(), ctc_loss_, (logits, targetLabels, logitsLengths, targetLabelLengths, logLosses, gradients, blankIndex), FLOAT_TYPES, INDEXING_TYPES); - } + BUILD_DOUBLE_SELECTOR(logits.dataType(), targetLabels.dataType(), ctc_loss_, (logits, targetLabels, logitsLengths, targetLabelLengths, logLosses, gradients, blankIndex), FLOAT_TYPES, INDEXING_TYPES); + } - BUILD_DOUBLE_TEMPLATE(template void ctc_loss_, (const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex), FLOAT_TYPES, INDEXING_TYPES); + BUILD_DOUBLE_TEMPLATE(template void ctc_loss_, (const NDArray &logits, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex), FLOAT_TYPES, INDEXING_TYPES); } // namespace helpers diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/ctc.h 
b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/ctc.h new file mode 100644 index 000000000..46155f6c7 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/ctc.h @@ -0,0 +1,153 @@ +/******************************************************************************* + * Copyright (c) 2021 Deeplearning4j Contributors + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + *******************************************************************************/ + +// +// @author AbdelRauf +// + +#ifndef LIBND4J_HELPERS_CTCLOSS_H +#define LIBND4J_HELPERS_CTCLOSS_H + +#include +#include +#include +#include +namespace sd { +namespace ops { +namespace helpers { + + //#define LOGIT_SOFTMAX_NORMALIZATION 1 + + template + constexpr T negative_infinity() + { + return -DataTypeUtils::infOrMax(); + } + + //choose ptr[index*element_stride] + template + typename std::enable_if::type + element(Type *ptr, int index, int element_stride) + { + return ptr[index * element_stride]; + } + + //choose ptr[index] assuming element_stride is 1 + template + typename std::enable_if::type + element(Type *ptr, int index, int element_stride) + { + return ptr[index]; + } + + template + T local_log(T x) + { + if (x > 0) + { + return (sd::math::p_log(x)); + } + return (negative_infinity()); + } + + template + T log_sum_exp(T x1, T x2) + { + //substituting this : std::log(std::exp(arg1 - cMax) + std::exp(arg2 - cMax)) + cMax + //if arg1==cMax : 
std::log(1 + std::exp(arg2 - cMax)) + cMax + if (x1 >= x2) + { + //x1 is max + return (x1 + local_log(1 + sd::math::p_exp(x2 - x1))); + } + //x2 is max + return (x2 + local_log(1 + sd::math::p_exp(x1 - x2))); + } + + template + T log_sum_exp(T arg1, T arg2, T arg3) + { + auto c_max = std::max(arg1, arg2); + c_max = std::max(c_max, arg3); + if (negative_infinity() == c_max) + { + c_max = 0; + } + return sd::math::p_log(sd::math::p_exp(arg1 - c_max) + sd::math::p_exp(arg2 - c_max) + sd::math::p_exp(arg3 - c_max)) + c_max; + } + + template + Type softmax_normalization_term(const Type* log_p, const uint64_t len_c, const uint64_t element_stride) + { + Type max_p; + for (auto c = 0; c < len_c; ++c) { + max_p = std::max(max_p, element(log_p, c, element_stride)); + } + // Get normalization term of softmax: log(sum(exp(logit[j]-max_p))). + Type logsumexp = Type(0.0); + for (auto c = 0; c < len_c; ++c) { + logsumexp += sd::math::p_exp(element(log_p, c, element_stride) - max_p); + } + logsumexp = sd::math::p_log(logsumexp); + return max_p + logsumexp; + } + + /** + * @brief Implementation of CTC loss function + * References: + Connectionist Temporal Classification - Labeling Unsegmented Sequence Data + with Recurrent Neural Networks: + [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) + ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) + * + * @param block Context + * @param logits NDArray {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }. It should include a blank label as well. + * NOTE: log softmax of rnn output. so we expect softmax normalized + * @param targetLabels NDArray {BATCH_LEN, MAX_TARGET_LEN} + * @param logitsLengths NDArray {BATCH_LEN} Length of input sequence in logits + * @param targetLabelLengths NDArray {BATCH_LEN} Length of label sequence in labels + * @param logLosses NDArray {BATCH_LEN} or EMPTY. if empty it will be skipped. negative log probabilities of loss + * @param gradients NDArray {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN } or EMPTY. 
gradients + * @param blankIndex index of the blank label in logits + */ + void ctcLoss(graph::Context& block, const NDArray &logitsInput, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex); + + + /** + * @brief Implementation of CTC beam search + * + * @param logit NDArray {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }. log probabilities. It should include a blank label as well. + * @param sequence_length NDArray {BATCH_LEN} length of frames. type integer + * @param result_sequences NDArray {BATCH_LEN, NBEST, MAX_FRAME_LEN} result sequences. + * NOTE: result_sequences NdArray should be c order and have ews == 1. type integer. + * @param result_probs NDArray {BATCH_LEN, NBEST} negative log probabilities for each sequence. + * NOTE: result_probs NdArray should be c order and have ews == 1 + * @param result_sequences_length NDArray {BATCH_LEN, NBEST} the length of each sequence in result_sequences. + * NOTE: result_sequences_length NdArray should be c order and have ews == 1 + * @param blank_index the index of the blank label in logits + * @param beam_width the width of the beam search. + * @param nbest_len the number of top best results that should be returned. if it is greather than beam_width it will be defaulted to beam_width size. + * @param normalize_logits when its true it will normalize logits. by default it is assumed logit contains already normalized log-probabilities + * NOTE: + * maximum value of integer type should be >= CLASS_LEN to make sense. And also user should consider frame lengthes as well. 
+ */ +void beamSearch(const NDArray& logit, const NDArray& sequence_length, NDArray& result_sequences, NDArray& result_probs, NDArray& result_sequences_length, int blank_index, int beam_width , int nbest_len, bool normalize_logits); +} +} +} + + +#endif diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/ctcLoss.h b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/ctcLoss.h deleted file mode 100644 index 320442456..000000000 --- a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/ctcLoss.h +++ /dev/null @@ -1,55 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2021 Deeplearning4j Contributors - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - *******************************************************************************/ - -// -// @author AbdelRauf -// - -#ifndef LIBND4J_HELPERS_CTCLOSS_H -#define LIBND4J_HELPERS_CTCLOSS_H - -#include -#include - -namespace sd { -namespace ops { -namespace helpers { - - /** - * @brief Implementation of CTC loss function - * References: - Connectionist Temporal Classification - Labeling Unsegmented Sequence Data - with Recurrent Neural Networks: - [Graves et al., 2006](https://dl.acm.org/citation.cfm?id=1143891) - ([pdf](http://www.cs.toronto.edu/~graves/icml_2006.pdf)) - * - * @param block Context - * @param logits NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN }. log softmax of rnn output. 
It should include a blank label as well. - * @param targetLabels NDArray {BATCH_LEN, MAX_TARGET_LEN} - * @param logitsLengths NDArray {BATCH_LEN} Length of input sequence in logits - * @param targetLabelLengths NDArray {BATCH_LEN} Length of label sequence in labels - * @param logLosses NDArray {BATCH_LEN} or EMPTY. if empty it will be skipped. negative log probabilities of loss - * @param gradients NDArray {BATCH_LEN, FRAME_LEN, CLASS_LEN } or EMPTY. gradients - * @param blankIndex index of the blank label in logits - */ - void ctcLoss(graph::Context& block, const NDArray &logitsInput, const NDArray &targetLabels, const NDArray &logitsLengths, const NDArray &targetLabelLengths, NDArray &logLosses, NDArray &gradients, int blankIndex); - -} -} -} - - -#endif // LIBND4J_ADDBIAS_H \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/cuda/ctcLoss.cu b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/cuda/ctcLoss.cu index 953dd6984..ead5f3ffb 100644 --- a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/cuda/ctcLoss.cu +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/cuda/ctcLoss.cu @@ -25,7 +25,7 @@ #include #include #include -#include +#include namespace sd { diff --git a/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/impl/ctcBeam.cpp b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/impl/ctcBeam.cpp new file mode 100644 index 000000000..9a5da12c5 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/main/cpp/blas/ops/declarable/helpers/impl/ctcBeam.cpp @@ -0,0 +1,718 @@ + +/******************************************************************************* + * Copyright (c) 2021 Deeplearning4j Contributors + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * 
https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + *******************************************************************************/ + +// +// @author AbdelRauf +// + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace sd { +namespace ops { +namespace helpers { + +template +struct BeamProb +{ + T total = negative_infinity(); + T non_blank = negative_infinity(); + T blank = negative_infinity(); //log(1) +}; + + +template +struct DefaultInvalid +{ + static constexpr T value = T(); +}; + + +template +struct DefaultInvalid::value>::type> +{ + static constexpr T value = static_cast(-1); +}; + +template +struct SequenceNode +{ + //intrusive double links + SequenceNode* prev = nullptr; + SequenceNode* next = nullptr; + + //sequence prefix/parent + SequenceNode* prefix = nullptr; + + T value = DefaultInvalid::value; + + int state = 0; + + void markAsFullyExtended() + { + state |= 1; + } + + void increaseRef() + { + //we will have just two copies in bad case. so just or + state = state | 2; + } + + void decreaseRef() + { + //we will have just two cases in bad case, so just remove that + state = state & (-2); + } + + bool safeToRemove() + { + + if (state & 1) return false; + + decreaseRef(); + //we do not want to remove parent nodes in our case. otherwise just returning state<=1 is ok + return state == 0; + } + + bool isFullyExtended() const { return state & 1; } +}; + +/*** + * Sequence container. 
+ * + * NOTE: it is not thread-safe + * + * Extend path - O(1) + * Remove path - O(1) + * Generating Sequence with backtracking prefix: O(n) + * + * Note: Sequence container is implemented primitively and only usable within this task. + * As it does not behave as a fully capable tree. some cases should be handled manually + * + * Here is special cases that should be handled manually to exploit tree/graph behaviour: + * + * Extending new path value: + * + * To extend the path one need to give path and value and in return get new_path: + * new_path = container.extendPath ( path, new_value ); + * + * Also note that: + * SequenceContainer has already default empty path as a beginning point for paths. + * So as an initial node one should use it. + * initial_path = container.getEmptyPath(); + * + * Adding new path that could be already in container: + * + * Assume we have two paths that can overlap in next step + * 1st path: node#0() -> node#1(1) => generated sequence {},{1} + * 2nd path: node#0() -> node#1(1) -> node#2(2) => generated sequence {},{1}, {2} + * + * While extending the first path with value (2). it will be: + * + * node#0() -> node#0(1) -> node#( either new or old)(2) => generated sequence {},{1}, {2} + * + * For some tasks its not desired to have additional node that will generate the same sequence. + * For example: + * Assume you wanted to use it as sequence entry in map with just (entry->prefix, entry->value). + * so in that case having different paths is not correct and will not be unique in map. + * + * there is not direct way to handle that in our container other than searching. + * So one should look for the node with prefix node#1(1) and value(2) and return that node instead of adding new one + + * Fortunately, for our beam search case: + * + * we need only look for such overlapped cases within the candidates list. + * which makes it easy to determine them beforehand while finding and marking overlapped cases. 
instead of looking for it in SequenceContainer + * + * Removing the same nodes multiple times: + * It is fast to remove nodes. As nodes can be stored externally One should follow this rule: + * + * One should not remove the same node twice as it will lead to double free. as Nodes are pointers the same applies to removing a copy + * + * There could be cases where you would like to store copy of nodes. in that cases you can use below method to be able to safely remove: + * node should have mutable method named safeToRemove(). + * Basic implementation will be decreasing reference/copy counts and returning true if it is safe to delete + * + * + */ +template +class SequenceContainer +{ +public: + SequenceContainer() : count_(1) + { + empty_path = new SequenceNode(); + current_ = empty_path; + } + + SequenceContainer(const SequenceContainer& s) = delete; + + SequenceContainer(SequenceContainer&& other) noexcept + { + this->current_ = other.current_; + other.current_ = nullptr; + } + + SequenceContainer& operator=(const SequenceContainer& other) = delete; + + SequenceContainer& operator=(SequenceContainer&& other) noexcept + { + if (this != other) + { + clear(); + this->current_ = other.current_; + this->count_ = other.count_; + other.current_ = nullptr; + other.count_ = 0; + } + return *this; + } + + SequenceNode* getEmptyPath() + { + return current_; + } + + SequenceNode* extendPath(SequenceNode* prefix, T value) + { + auto new_node = new SequenceNode(); + + new_node->value = value; + new_node->prefix = prefix; + //add in the holder + new_node->next = nullptr; + new_node->prev = current_; + /*std::cout << "add " << (long long)new_node << std::endl; + print_seq1(new_node);*/ + if (current_) current_->next = new_node; + + current_ = new_node; + count_++; + return new_node; + } + + void remove(SequenceNode* seq) + { + if (seq == nullptr) return; + + if (!seq->safeToRemove()) return; + + SequenceNode* previous = seq->prev; + SequenceNode* next = seq->next; + if (previous) 
previous->next = next; + if (next) next->prev = previous; + + if (current_ == seq) + { + current_ = previous; + } + //std::cout << "remove " << (long long)seq << " " << std::endl; + //print_seq1(seq); + delete seq; + count_--; + } + + static std::vector getSequence(SequenceNode* seq, size_t reserve_size = 1024) + { + std::vector ret; + ret.reserve(reserve_size); + SequenceNode* backtrack = seq; + while (backtrack) + { + ret.push_back(backtrack->value); + backtrack = backtrack->prefix; + } + if (ret.size() > 1) + { + //remove last default node + ret.pop_back(); + //reverse + std::reverse(std::begin(ret), std::end(ret)); + return ret; + } + return {}; + } + + void clear() + { + //destruct all nodes + SequenceNode* del = current_; + //int i = 0; + while (del) + { + //++i; + SequenceNode* temp = del->prev; + delete del; + del = temp; + } + current_ = nullptr; + //assert(count_==i); + } + + ~SequenceContainer() + { + clear(); + } + +private: + SequenceNode* current_ = nullptr; + + SequenceNode* empty_path = nullptr; + + int count_ = 0; +}; + +template +struct BeamEntry +{ + SequenceNode* sequence{}; + BeamProb prob; +}; + + +template +struct BeamEntryEx +{ + BeamEntry entry; + //keep indices for lookUp + int index_as_child = -1; + int index_as_parent = -1; + int children_count = 0; +}; + +template +struct LookUpEntry +{ + U last_c; //this is is the same as node->value. just we added for the speed + SequenceNode* node = nullptr; + int next_beam_index = -1; //index inside next_beam array +}; + +template +bool compare_beam_prob(const BeamEntry& i1, const BeamEntry& i2) +{ + return (i1.prob.total > i2.prob.total); +} + + +template +T pr(const int c, const BeamProb& beam_prob, const SequenceNode* seq, const T prob) +{ + return seq->value == c ? 
beam_prob.blank + prob : beam_prob.total + prob; +} + +template +void inner_beam_search(const Type* log_p, const uint64_t inc_p, IndexType* result_sequence, const uint64_t inc_res_seq, + const uint64_t max_len_t, Type* result_prob, IndexType* result_seq_length, uint64_t len_t, + const uint64_t len_c, const int blank_index, int beam_width, int nbest_len, bool normalize_logits, const uint64_t element_stride = 1L) +{ + + using BeamEntryType = BeamEntry; + using BeamEntryTypeEx = BeamEntryEx; + + if (beam_width < 1) beam_width = 1; + if (nbest_len > beam_width) nbest_len = beam_width; + //if len_t is greater than max_len_t truncate it + len_t = len_t > max_len_t ? max_len_t : len_t; + + SequenceContainer sequence_container; + BeamEntryType empty; + empty.prob.blank = 0; + empty.prob.total = log_sum_exp(empty.prob.blank, empty.prob.non_blank); + empty.sequence = sequence_container.getEmptyPath(); + + //vectors: we will use it as array, here + std::vector last_beams; + std::vector next_beams; + last_beams.resize(beam_width); + //as we skip blank indexes the count is beam_width * len_c + next_beams.resize(beam_width * len_c); + last_beams[0].entry = empty; + last_beams[0].index_as_child = -1; + last_beams[0].index_as_parent = -1; + last_beams[0].children_count = 0; + auto last_beam_size = 1; + + // lookupContainer: + // it will keep sorted entries. so we will just move and compare the entry + // in each step there will be overlapped cases + // the size of overlapped cases in last_beam[0:beam_width]: + // as we have beam_width size in each step after sort and pruning + // there is at least one item who will not have any parent + // and for the rest (beam_width-1) it will check has_parent_in_container() ? 
1 : 0 + // so maximum size of overlapped pairs is beam_width-1 + + std::vector> lookUp; + lookUp.resize(beam_width - 1); + + //additional storage to sort overlapped case by classes + std::vector> child_class_sorter_help; + child_class_sorter_help.resize(beam_width - 1); + Type norm_offset = 0; + + for (uint64_t t = 0; t < len_t; t++) + { + auto next_beam_size = 0; + if (normalize_logits){ + norm_offset = softmax_normalization_term(log_p, len_c, element_stride); + } + for (auto j = 0; j < last_beam_size; j++) + { + SequenceNode* seq = last_beams[j].entry.sequence; + auto& cur_prob = last_beams[j].entry.prob; + //if len(seq) > 0 then + const auto log_p_blank = element(log_p, blank_index, element_stride); + Type blank_prob, non_blank_prob; + //log_p[seq->value] + non_blank_prob = seq->value != -1 ? (element(log_p, seq->value, element_stride) + cur_prob.non_blank) : negative_infinity(); + blank_prob = log_p_blank + cur_prob.total; + + if (normalize_logits){ + non_blank_prob = non_blank_prob - norm_offset; + blank_prob = blank_prob - norm_offset; + } + + auto look_up_beam_index = -1; + + if (last_beams[j].index_as_child != -1) + { + //check entry + look_up_beam_index = lookUp[last_beams[j].index_as_child].next_beam_index; + } + + if (look_up_beam_index == -1) + { + BeamEntryType entry; + entry.sequence = seq; + entry.prob.blank = blank_prob; + entry.prob.non_blank = non_blank_prob; + entry.prob.total = log_sum_exp(blank_prob, non_blank_prob); + next_beams[next_beam_size] = entry; + //map if its overlapped one. 
in this case just being child is enough + if (last_beams[j].index_as_child != -1) + { + lookUp[last_beams[j].index_as_child].next_beam_index = next_beam_size; + } + ++next_beam_size; + } + else + { + //note: here we took as ref & + auto& entry_prob = next_beams[look_up_beam_index].prob; + entry_prob.blank = log_sum_exp(entry_prob.blank, blank_prob); + entry_prob.non_blank = log_sum_exp(entry_prob.non_blank, non_blank_prob); + entry_prob.total = log_sum_exp(entry_prob.blank, entry_prob.non_blank); + } + //check to see if it is overlapped parent + auto start_index = last_beams[j].index_as_parent; + auto end_index = last_beams[j].index_as_parent + last_beams[j].children_count; + + for (int c = 0; c < len_c; c++) + { + if (c == blank_index) continue; + + const auto prob = element(log_p, c, element_stride);//log_p[c]; + + non_blank_prob = pr(c, cur_prob, seq, prob); + if(normalize_logits) non_blank_prob = non_blank_prob - norm_offset; + //extend by new character + auto look_up_beam_index_ex = -1; + int found_index = -1; + + //get index within array if its that class index + if (start_index < end_index && lookUp[start_index].last_c == c){ + look_up_beam_index_ex = lookUp[start_index].next_beam_index; + + found_index = start_index; + ++start_index; + } + + if (look_up_beam_index_ex == -1) + { + BeamEntryType entry; + SequenceNode* extended_sequence; + if (found_index!=-1) + { + extended_sequence = lookUp[found_index].node; + //assing next_beam_index for lookup + lookUp[found_index].next_beam_index = next_beam_size; + extended_sequence->increaseRef(); + } + else { + extended_sequence = sequence_container.extendPath(seq, c); + } + entry.prob.non_blank = non_blank_prob; + entry.prob.total = non_blank_prob; + entry.sequence = extended_sequence; + next_beams[next_beam_size] = entry; + + ++next_beam_size; + } + else + { + auto& entry_prob = next_beams[look_up_beam_index_ex].prob; + entry_prob.non_blank = log_sum_exp(entry_prob.non_blank, non_blank_prob); + entry_prob.total = 
log_sum_exp(entry_prob.total, non_blank_prob); + } + } //iteration over classes + + //mark it as extended + seq->markAsFullyExtended(); + + } //iteration over beams + + log_p += inc_p; + + last_beam_size = std::min(next_beam_size, beam_width); +#if !defined(NTH_ELEMENT) + //sort next beams to get candidates + std::partial_sort(std::begin(next_beams), + std::begin(next_beams) + last_beam_size, + std::begin(next_beams) + next_beam_size, compare_beam_prob); + +#else + std::nth_element(std::begin(next_beams), + std::begin(next_beams) + last_beam_size, + std::begin(next_beams) + next_beam_size, compare_beam_prob); + +#endif + + if (t < len_t) + { + //copy top beams + for (int j = 0; j < last_beam_size; j++) + { + last_beams[j].entry = next_beams[j]; + last_beams[j].index_as_child = -1; + last_beams[j].index_as_parent = -1; + last_beams[j].children_count = 0; + } + + //delete sequences from the sequence_holder to decrease memory + for (auto j = beam_width; j < next_beam_size; j++) + { + sequence_container.remove(next_beams[j].sequence); + } + + //check overlapping cases and create lookUp with sorted classes as well + int look_up_index = 0; + for (auto j = 0; j < last_beam_size; j++) + { + //if it is not parent node then there is not any need to check + if (last_beams[j].entry.sequence->isFullyExtended()) + { + auto parent_seq=last_beams[j].entry.sequence; + int children_count = 0; + for (int k = 0; k < last_beam_size; k++) + { + auto current = last_beams[k].entry.sequence; + if (current->prefix == parent_seq) + { + child_class_sorter_help[children_count].first = current->value; + child_class_sorter_help[children_count].second = k ; + ++children_count ; + } + } + + if (children_count > 0) + { + + //sort by class + if(children_count<2){ + // + if (children_count > 1 && child_class_sorter_help[0].first > child_class_sorter_help[1].first) + { + std::swap(child_class_sorter_help[0], child_class_sorter_help[1]); + } + } + else + { + 
std::sort(std::begin(child_class_sorter_help), std::begin(child_class_sorter_help) + children_count, + [](const std::pair& left, const std::pair& right) { + return left.first < right.first; + }); + } + last_beams[j].index_as_parent = look_up_index; + last_beams[j].children_count = children_count; + + for (int l = 0; l < children_count; l++) + { + + int c = child_class_sorter_help[l].first; + int k = child_class_sorter_help[l].second; + //std::cout << c <<" , " << k << std::endl; + last_beams[k].index_as_child = look_up_index; + auto seq = last_beams[k].entry.sequence; + lookUp[look_up_index].last_c = c; + lookUp[look_up_index].node = seq; + lookUp[look_up_index].next_beam_index = -1; + //next one + ++look_up_index; + } + }//add sorted lookUps + + } + } //overlap_direction identified to speed up lookUp + + } + + }//iterate over t +#if defined(NTH_ELEMENT) + //use sort for n elements as only nth_element was used + std::sort(std::begin(next_beams), std::begin(next_beams) + last_beam_size, compare_beam_prob); +#endif + //store nbest results + if (nbest_len <= last_beam_size) { + for (int j = 0; j < nbest_len; j++) + { + auto top = next_beams[j]; + auto result_vector = SequenceContainer::getSequence(top.sequence, len_t); + const auto seq_size = result_vector.size(); + + result_prob[j] = top.prob.total; + result_seq_length[j] = seq_size; + //copy sequence + for (auto s = 0; s < seq_size; s++) + { + result_sequence[s] = result_vector[s]; + } + + result_sequence += inc_res_seq; + + } + } + else + { + for (int j = 0; j < nbest_len; j++) + { + result_prob[j] = negative_infinity(); + result_seq_length[j] = 0;; + } + } + return; +} + +template +void +beamSearch_(const NDArray& logit, const NDArray& sequence_length, NDArray& result_sequences, NDArray& result_probs, NDArray& result_sequences_length, int blank_index, int beam_width, int nbest_len, bool normalize_logits ) +{ + + const auto shapes = logit.shapeOf(); + const auto strides = logit.stridesOf(); + const auto rank = 
logit.rankOf(); + + const IndexType* len_t_ptr = nullptr; + uint64_t element_stride_t = 1; + + //checks before + if (rank < 2) return; + auto batch_len = rank > 2 ? shapes[0] : 1; + auto max_len_t = shapes[rank - 2]; + auto len_c = shapes[rank - 1]; + + if (len_c < 1 || max_len_t < 1) return; + //defaulting blankIndex to the last class if its incorrect or -1 + if (blank_index > len_c || blank_index < 0) blank_index = static_cast(len_c) - 1; + if (sequence_length.rankOf() == 1 && sequence_length.shapeOf()[0] == batch_len) + { + len_t_ptr = sequence_length.bufferAsT(); + element_stride_t = sequence_length.stridesOf()[0]; + } + + //strides + auto batch_stride = rank > 2 ? strides[0] : 0; + auto inc_p = strides[rank - 2]; + auto element_stride = logit.stridesOf()[rank - 1]; + + auto logits_ptr = logit.bufferAsT(); + +#if defined(ASSERT_INNER) + //result_probs should be [batch_len, nbest_len] + assert(result_probs.ews() == 1 && result_probs.rankOf() == 2 && result_probs.shapeOf()[0] == batch_len && result_probs.shapeOf()[1] == nbest_len); + //result sequence should be [batch_len, nbest_len, max_len_t] + assert(result_sequences.ews() == 1 && result_sequences.rankOf() == 3 && result_sequences.shapeOf()[0] == batch_len && result_sequences.shapeOf()[1] == nbest_len + && result_sequences.shapeOf()[2] == max_len_t); +#endif + auto result_seq_ptr = result_sequences.bufferAsT(); + auto result_probs_ptr = result_probs.bufferAsT(); + auto result_seq_length_ptr = result_sequences_length.bufferAsT(); + + const auto batch_stride_res = result_sequences.stridesOf()[0]; + const auto inc_res = result_sequences.stridesOf()[1]; + const auto batch_stride_res_prob = result_probs.stridesOf()[0]; + const auto batch_stride_res_seq_length = result_sequences_length.stridesOf()[0]; + auto func = [max_len_t, len_c, batch_stride, inc_p, element_stride, element_stride_t, logits_ptr, len_t_ptr, blank_index, beam_width, normalize_logits, + nbest_len, result_seq_ptr, result_seq_length_ptr, 
result_probs_ptr, batch_stride_res, inc_res, batch_stride_res_prob, batch_stride_res_seq_length] + (uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void + { + + auto ptr = logits_ptr + start * batch_stride; + + if (element_stride == 1) + { + //choose ews one + for (auto b = start; b < stop; b += increment) + { + auto prob_ptr = &(result_probs_ptr[b * batch_stride_res_prob]); + auto seq_length_ptr = &(result_seq_length_ptr[b * batch_stride_res_seq_length]); + auto seq_ptr = &(result_seq_ptr[b * batch_stride_res]); + + auto len_t = len_t_ptr ? len_t_ptr[b * element_stride_t] : max_len_t; + inner_beam_search(ptr, inc_p, seq_ptr, inc_res, max_len_t, prob_ptr, seq_length_ptr, len_t, len_c, blank_index, beam_width, nbest_len, normalize_logits); + + ptr += batch_stride; + + } + } + else + { + // element with stride case + for (auto b = start; b < stop; b += increment) + { + auto prob_ptr = &(result_probs_ptr[b * batch_stride_res_prob]); + auto seq_length_ptr = &(result_seq_length_ptr[b * batch_stride_res_seq_length]); + auto seq_ptr = &(result_seq_ptr[b * batch_stride_res]); + + auto len_t = len_t_ptr ? 
len_t_ptr[b * element_stride_t] : max_len_t; + inner_beam_search(ptr, inc_p, seq_ptr, inc_res, max_len_t, prob_ptr, seq_length_ptr, len_t, len_c, blank_index, beam_width, nbest_len, normalize_logits, element_stride); + + ptr += batch_stride; + } + } + }; + samediff::Threads::parallel_for(func, 0, batch_len, 1); + return; +} + +void beamSearch(const NDArray& logit, const NDArray& sequence_length, NDArray& result_sequences, NDArray& result_probs, NDArray& result_sequences_length, int blank_index, int beam_width , int nbest_len, bool normalize_logits = true){ + + BUILD_DOUBLE_SELECTOR(logit.dataType(), result_sequences.dataType(), beamSearch_, (logit, sequence_length, result_sequences, result_probs, result_sequences_length, blank_index, beam_width , nbest_len, normalize_logits), FLOAT_TYPES, INDEXING_TYPES); +} + + +BUILD_DOUBLE_TEMPLATE(template void beamSearch_, (const NDArray& logit, const NDArray& sequence_length, NDArray& result_sequences, NDArray& result_probs, NDArray& result_sequences_length, int blank_index, int beam_width , int nbest_len, bool normalize_logits), FLOAT_TYPES, INDEXING_TYPES); + +}}} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/CMakeLists.txt b/cavis-native/cavis-native-lib/src/test/tests_cpu/CMakeLists.txt new file mode 100644 index 000000000..5de17a2d1 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/CMakeLists.txt @@ -0,0 +1,31 @@ +cmake_minimum_required(VERSION 3.15) +project(tests_cpu) + +# Download and unpack googletest at configure time +configure_file(CMakeLists.txt.in googletest-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download ) +if(result) + message(FATAL_ERROR "CMake step for googletest failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . 
+ RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download ) +if(result) + message(FATAL_ERROR "Build step for googletest failed: ${result}") +endif() + +# Prevent overriding the parent project's compiler/linker +# settings on Windows +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + +# Add googletest directly to our build. This defines +# the gtest and gtest_main targets. +add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/googletest-src + ${CMAKE_CURRENT_BINARY_DIR}/googletest-build + EXCLUDE_FROM_ALL) + +set(gtest_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/googletest-src) +#add_subdirectory(libnd4j_tests) +add_subdirectory(layers_tests) diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/CMakeLists.txt.in b/cavis-native/cavis-native-lib/src/test/tests_cpu/CMakeLists.txt.in new file mode 100644 index 000000000..a3cba4d27 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/CMakeLists.txt.in @@ -0,0 +1,16 @@ +cmake_minimum_required(VERSION 2.8.2) + +project(googletest-download NONE) + +include(ExternalProject) +ExternalProject_Add(googletest + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG release-1.10.0 + SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-src" + BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-build" + CMAKE_ARGS "" + CONFIGURE_COMMAND "" + BUILD_COMMAND "" + INSTALL_COMMAND "" + TEST_COMMAND "" +) diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AllTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AllTests.cpp new file mode 100644 index 000000000..f1328894a --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AllTests.cpp @@ -0,0 +1,47 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * 
https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 04.08.17. +// +// +#include "testlayers.h" +/* +#include "DenseLayerTests.cpp" +#include "NDArrayTests.cpp" +#include "VariableSpaceTests.cpp" +#include "VariableTests.cpp" +#include "DeclarableOpsTests.cpp" +#include "HashUtilsTests.cpp" +#include "WorkspaceTests.cpp" +#include "ConvolutionTests.cpp" +#include "TadTests.cpp" +#include "StashTests.cpp" +#include "SessionLocalTests.cpp" +#include "GraphTests.cpp" +#include "FlatBuffersTests.cpp" + */ +/////// + +//#include "CyclicTests.h" +// #include "ProtoBufTests.cpp" + +int main(int argc, char **argv) { + testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ArrayOptionsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ArrayOptionsTests.cpp new file mode 100644 index 000000000..1436d77c3 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ArrayOptionsTests.cpp @@ -0,0 +1,111 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 13.01.2018. +// + +#include "testlayers.h" +#include +#include + +using namespace sd; + + +class ArrayOptionsTests : public testing::Test { +public: + Nd4jLong shape[8] = {2, 5, 5, 5, 1, 0, 1, 99}; +}; + +TEST_F(ArrayOptionsTests, TestShape_Basic_0) { + shape[5] = 1; + + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); + ASSERT_FALSE(ArrayOptions::isSparseArray(shape)); +} + + +TEST_F(ArrayOptionsTests, TestShape_Basic_1) { + shape[5] = 2; + + + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); + ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); +} + +TEST_F(ArrayOptionsTests, TestShape_Basic_2) { + shape[5] = 258; + + ASSERT_TRUE(ArrayOptions::isNewFormat(shape)); + + ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); + ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); +} + +TEST_F(ArrayOptionsTests, TestShape_Basic_3) { + ASSERT_EQ(0, shape::extra(shape)); + + ASSERT_EQ(SpaceType::CONTINUOUS, ArrayOptions::spaceType(shape)); +} + +TEST_F(ArrayOptionsTests, TestShape_Basic_4) { + + ArrayOptions::setPropertyBits(shape, {ARRAY_HALF, ARRAY_QUANTIZED}); + + auto dtype = ArrayOptions::dataType(shape); + + ASSERT_FALSE(ArrayOptions::isSparseArray(shape)); + ASSERT_TRUE(sd::DataType::HALF == ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::ArrayType::DENSE, ArrayOptions::arrayType(shape)); + ASSERT_EQ(sd::SpaceType::QUANTIZED, ArrayOptions::spaceType(shape)); +} + 
+TEST_F(ArrayOptionsTests, TestShape_Basic_5) { + ArrayOptions::setPropertyBits(shape, {ARRAY_SPARSE, ARRAY_INT, ARRAY_CSC}); + + ASSERT_TRUE(ArrayOptions::isSparseArray(shape)); + ASSERT_TRUE(sd::DataType::INT32 == ArrayOptions::dataType(shape)); + ASSERT_EQ(sd::SparseType::CSC, ArrayOptions::sparseType(shape)); +} + +TEST_F(ArrayOptionsTests, TestShape_Basic_6) { + ArrayOptions::setPropertyBits(shape, {ARRAY_EMPTY, ARRAY_INT, ARRAY_CSC}); + + ASSERT_EQ(sd::ArrayType::EMPTY, ArrayOptions::arrayType(shape)); +} + +TEST_F(ArrayOptionsTests, TestShape_Basic_7) { + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + + ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape)); +} + +TEST_F(ArrayOptionsTests, TestShape_Basic_8) { + ArrayOptions::setDataType(shape, sd::DataType::DOUBLE); + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + + ASSERT_EQ(sd::DataType::FLOAT32, ArrayOptions::dataType(shape)); +} + +TEST_F(ArrayOptionsTests, TestShape_Basic_9) { + ArrayOptions::setDataType(shape, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape, sd::DataType::DOUBLE); + + ASSERT_EQ(sd::DataType::DOUBLE, ArrayOptions::dataType(shape)); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AtomicTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AtomicTests.cu new file mode 100644 index 000000000..5c9c1aa1a --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AtomicTests.cu @@ -0,0 +1,243 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include + + +using namespace sd; + + +class AtomicTests : public testing::Test { +public: + AtomicTests() { + // + } +}; + +template +static _CUDA_G void multiplyKernel(void *vbuffer, uint64_t length, void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; + + sd::math::atomics::nd4j_atomicMul(&result[i], buffer[e]); + } +} + +template +static void multiplyLauncher(void *vbuffer, uint64_t length, void *vresult) { + multiplyKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); + auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("multiply failed", err); +} + +template +static _CUDA_G void sumKernel(void *vbuffer, uint64_t length, void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; + + sd::math::atomics::nd4j_atomicAdd(&result[i], buffer[e]); + } +} + +template +static void sumLauncher(void *vbuffer, uint64_t length, void *vresult) { + 
sumKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); + auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("sum failed", err); +} + +template +static _CUDA_G void subKernel(void *vbuffer, uint64_t length, void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; + + sd::math::atomics::nd4j_atomicSub(&result[i], buffer[e]); + } +} + +template +static void subLauncher(void *vbuffer, uint64_t length, void *vresult) { + subKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); + auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("sub failed", err); +} + +template +static _CUDA_G void divKernel(void *vbuffer, uint64_t length, void *vresult) { + auto buffer = reinterpret_cast(vbuffer); + auto result = reinterpret_cast(vresult); + + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + + for (auto e = tid; e < length; e += gridDim.x * blockDim.x) { + auto rem = e % 4; + auto i = (e - rem) / 4; + + sd::math::atomics::nd4j_atomicDiv(&result[i], buffer[e]); + } +} + +template +static void divLauncher(void *vbuffer, uint64_t length, void *vresult) { + divKernel<<<256, 256, 1024, *sd::LaunchContext::defaultContext()->getCudaStream()>>>(vbuffer, length, vresult); + auto err = cudaStreamSynchronize(*sd::LaunchContext::defaultContext()->getCudaStream()); + if (err != 0) + throw sd::cuda_exception::build("div failed", err); +} + +static void multiplyHost(NDArray &input, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), 
output.specialBuffer()), NUMERIC_TYPES); +} + +static void sumHost(NDArray &input, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), sumLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES); +} + +static void subHost(NDArray &input, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), subLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES); +} + +static void divHost(NDArray &input, NDArray &output) { + BUILD_SINGLE_SELECTOR(input.dataType(), divLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), FLOAT_TYPES); +} + +TEST_F(AtomicTests, test_multiply) { + std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::INT16, sd::DataType::HALF}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(2); + output.assign(2); + exp.assign(32); + + multiplyHost(input, output); + ASSERT_EQ(exp, output); + } +} + +TEST_F(AtomicTests, test_multiply_2) { + std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::HALF, sd::DataType::BFLOAT16}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1.5); + output.assign(2); + exp.assign(10.125); + + multiplyHost(input, output); +// output.printBuffer("multiply 2"); + ASSERT_EQ(exp, output); + } +} + +TEST_F(AtomicTests, test_sum) { + std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, sd::DataType::HALF, sd::DataType::INT16}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray 
output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1); + output.assign(1); + exp.assign(5); + + sumHost(input, output); +// output.printIndexedBuffer("Sum"); + ASSERT_EQ(exp, output); + } +} + +TEST_F(AtomicTests, test_sub) { + std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::HALF}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(1); + output.assign(5); + exp.assign(1); + + subHost(input, output); +// output.printBuffer("Sub"); + + ASSERT_EQ(exp, output); + } +} + +TEST_F(AtomicTests, test_div) { + std::vector dtypes = {sd::DataType::FLOAT32, sd::DataType::DOUBLE, sd::DataType::BFLOAT16, sd::DataType::HALF}; + + for (auto t:dtypes) { + nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str()); + NDArray input('c', {4, 25}, t); + NDArray output('c', {input.lengthOf() / 4}, t); + NDArray exp = output.ulike(); + + input.assign(2); + output.assign(32); + exp.assign(2); + + divHost(input, output); +// output.printBuffer("Div"); + ASSERT_EQ(exp, output); + } +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AttentionTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AttentionTests.cpp new file mode 100644 index 000000000..74a7e5e6b --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/AttentionTests.cpp @@ -0,0 +1,222 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. 
+ * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace sd; + + +class AttentionTests : public testing::Test { +public: + AttentionTests() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(AttentionTests, basic_dot_product_attention) { + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); + + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); +} + +/* +//Ignored: AB 2019/05/21 - Segmentation fault on on linux-ppc64le-cpu - https://github.com/deeplearning4j/deeplearning4j/issues/7657 +TEST_F(AttentionTests, basic_dot_product_attention_bp) { + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); + auto eps = NDArrayFactory::create('c', {10, 4, 1}); + + sd::ops::dot_product_attention_bp op; + auto result = op.execute({&queries, &keys, &values, &eps}, {}, {1, 0}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} +*/ + +TEST_F(AttentionTests, basic_dot_product_attention_with_weights) { + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = 
NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); + + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); +} + +TEST_F(AttentionTests, basic_dot_product_attention_with_mask) { + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); + auto mask = NDArrayFactory::create('c', {10, 3}); + mask.assign(1.); + + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); +} + +/* +//AB 2019/05/28 - Segfault on ppc64le - See issue #7657 +TEST_F(AttentionTests, basic_dot_product_attention_bp_with_mask) { + auto keys = NDArrayFactory::create('c', {10, 4, 3}); + auto values = NDArrayFactory::create('c', {10, 4, 3}); + auto queries = NDArrayFactory::create('c', {10, 4, 1}); + auto eps = NDArrayFactory::create('c', {10, 4, 1}); + auto mask = NDArrayFactory::create('c', {10, 3}); + mask.assign(1.); + + sd::ops::dot_product_attention_bp op; + auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, 0}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} + */ + +TEST_F(AttentionTests, multi_head_input_dot_product_attention_with_mask) { + auto keys = NDArrayFactory::create('c', {2, 5, 4, 3}); + auto values = NDArrayFactory::create('c', {2, 5, 4, 3}); + auto queries = NDArrayFactory::create('c', {2, 5, 4, 1}); + auto mask = NDArrayFactory::create('c', {2, 3}); + mask.assign(1.); + + sd::ops::dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values, &mask}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); +} + +/* +//AB 2019/05/30 - Segfault on ppc64le - See issue #7657 +TEST_F(AttentionTests, multi_head_input_dot_product_attention_bp_with_mask) { + auto keys = 
NDArrayFactory::create('c', {2, 5, 4, 3}); + auto values = NDArrayFactory::create('c', {2, 5, 4, 3}); + auto queries = NDArrayFactory::create('c', {2, 5, 4, 1}); + auto eps = NDArrayFactory::create('c', {2, 5, 4, 1}); + auto mask = NDArrayFactory::create('c', {2, 3}); + mask.assign(1.); + + sd::ops::dot_product_attention_bp op; + auto result = op.execute({&queries, &keys, &values, &eps, &mask}, {}, {1, 0}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} + */ + + +TEST_F(AttentionTests, basic_multi_head_dot_product_attention) { + auto keys = NDArrayFactory::create('c', {10, 4, 5}); + auto values = NDArrayFactory::create('c', {10, 4, 5}); + auto queries = NDArrayFactory::create('c', {10, 4, 2}); + + auto Wk = NDArrayFactory::create('c', {2, 3, 4}); + auto Wv = NDArrayFactory::create('c', {2, 3, 4}); + auto Wq = NDArrayFactory::create('c', {2, 3, 4}); + auto Wo = NDArrayFactory::create('c', {2* 3, 4}); + + sd::ops::multi_head_dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); +} + +/* +//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - disabling this pre-emptively - See issue #7657 +TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention) { + auto keys = NDArrayFactory::create('c', {10, 4, 5}); + auto values = NDArrayFactory::create('c', {10, 4, 5}); + auto queries = NDArrayFactory::create('c', {10, 4, 2}); + + auto Wk = NDArrayFactory::create('c', {2, 3, 4}); + auto Wv = NDArrayFactory::create('c', {2, 3, 4}); + auto Wq = NDArrayFactory::create('c', {2, 3, 4}); + auto Wo = NDArrayFactory::create('c', {2* 3, 7}); + + auto eps = NDArrayFactory::create('c', {10, 7, 2}); + + + sd::ops::multi_head_dot_product_attention_bp op; + auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &eps}, {}, {1, 0}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} + */ + 
+TEST_F(AttentionTests, basic_multi_head_dot_product_attention_with_mask) { + auto keys = NDArrayFactory::create('c', {10, 4, 5}); + auto values = NDArrayFactory::create('c', {10, 4, 5}); + auto queries = NDArrayFactory::create('c', {10, 4, 2}); + + auto Wk = NDArrayFactory::create('c', {2, 3, 4}); + auto Wv = NDArrayFactory::create('c', {2, 3, 4}); + auto Wq = NDArrayFactory::create('c', {2, 3, 4}); + auto Wo = NDArrayFactory::create('c', {2* 3, 4}); + + auto mask = NDArrayFactory::create('c', {10, 5}); + mask.assign(1.); + + + sd::ops::multi_head_dot_product_attention op; + auto result = op.evaluate({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &mask}, {1, 0}); + ASSERT_EQ(Status::OK(), result.status()); +} + +/* +//AB 2019/05/30 - Other attention BP tests are segfaulting on ppc64le - disabling this pre-emptively - See issue #7657 +TEST_F(AttentionTests, basic_multi_head_dot_product_bp_attention_with_mask) { + auto keys = NDArrayFactory::create('c', {10, 4, 5}); + auto values = NDArrayFactory::create('c', {10, 4, 5}); + auto queries = NDArrayFactory::create('c', {10, 4, 2}); + + auto Wk = NDArrayFactory::create('c', {2, 3, 4}); + auto Wv = NDArrayFactory::create('c', {2, 3, 4}); + auto Wq = NDArrayFactory::create('c', {2, 3, 4}); + auto Wo = NDArrayFactory::create('c', {2* 3, 7}); + + auto eps = NDArrayFactory::create('c', {10, 7, 2}); + + auto mask = NDArrayFactory::create('c', {10, 5}); + mask.assign(1.); + + + sd::ops::multi_head_dot_product_attention_bp op; + auto result = op.execute({&queries, &keys, &values, &Wk, &Wv, &Wq, &Wo, &eps, &mask}, {}, {1, 0}, {}); + ASSERT_EQ(Status::OK(), result->status()); + + delete result; +} + */ diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BackpropTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BackpropTests.cpp new file mode 100644 index 000000000..09aa3e87b --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BackpropTests.cpp @@ -0,0 
+1,51 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 13.01.2018. +// + +#include "testlayers.h" +#include + +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class BackpropTests : public testing::Test { +public: + +}; + +TEST_F(BackpropTests, Test_Add_1) { + + NDArray x('c', {2, 3, 4}, sd::DataType::FLOAT32); + NDArray y('c', {3, 4}, sd::DataType::FLOAT32); + NDArray e('c', {2, 3, 4}, sd::DataType::FLOAT32); + + sd::ops::add_bp op; + auto result = op.evaluate({&x, &y, &e}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto eps = result.at(0); + auto grad = result.at(1); + + ASSERT_TRUE(x.isSameShape(eps)); + ASSERT_TRUE(y.isSameShape(grad)); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BitwiseUtilsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BitwiseUtilsTests.cpp new file mode 100644 index 000000000..8c9b0ddb5 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BitwiseUtilsTests.cpp @@ -0,0 +1,78 @@ +/* ****************************************************************************** + * + * 
+ * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 10.11.2017. +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class BitwiseUtilsTests : public testing::Test { +public: + +}; + +// oviously, this test will fail on big-endian machines, but who cares +TEST_F(BitwiseUtilsTests, Test_Runtime_Endianess_1) { + bool isBE = BitwiseUtils::isBE(); + + ASSERT_FALSE(isBE); +} + +TEST_F(BitwiseUtilsTests, Test_ValueBit_1) { + int idx = BitwiseUtils::valueBit(1); + + ASSERT_EQ(0, idx); +} + +TEST_F(BitwiseUtilsTests, Test_ValueBit_2) { + int idx = BitwiseUtils::valueBit(2); + + ASSERT_EQ(1, idx); +} + +TEST_F(BitwiseUtilsTests, Test_ValueBits_1) { + std::vector expected({1, 1}); + while (expected.size() < 32) + expected.push_back(0); + + std::vector result = BitwiseUtils::valueBits(3); + + ASSERT_EQ(32, result.size()); + ASSERT_EQ(expected, result); +} + +TEST_F(BitwiseUtilsTests, Test_ValueBits_2) { + int value = 48; + int flipped = BitwiseUtils::flip_bits(value); + + ASSERT_NE(value, flipped); + + auto o = BitwiseUtils::valueBits(value); + auto f = BitwiseUtils::valueBits(flipped); + + for (int e = 0; e < o.size(); e++) + ASSERT_NE(o.at(e), f.at(e)); +} \ No newline at 
end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BooleanOpsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BooleanOpsTests.cpp new file mode 100644 index 000000000..78354111b --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BooleanOpsTests.cpp @@ -0,0 +1,150 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 13.10.2017. 
+// + +#include "testlayers.h" +#include + +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class BooleanOpsTests : public testing::Test { +public: + +}; + + +TEST_F(BooleanOpsTests, LtTest_1) { + auto x = NDArrayFactory::create_(1.0f); + auto y = NDArrayFactory::create_(2.0f); + + sd::ops::lt_scalar op; + + + ASSERT_TRUE(op.verify({x, y})); + + delete x; + delete y; +} + +TEST_F(BooleanOpsTests, LtTest_2) { + auto x = NDArrayFactory::create_(2.0f); + auto y = NDArrayFactory::create_(1.0f); + + sd::ops::lt_scalar op; + + + ASSERT_FALSE(op.verify({x, y})); + + delete x; + delete y; +} + +TEST_F(BooleanOpsTests, Is_non_decreasing_1) { + auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 4}); + + sd::ops::is_non_decreasing op; + + ASSERT_TRUE(op.verify({&x})); + +} + +TEST_F(BooleanOpsTests, Is_non_decreasing_2) { + auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 3}); + + sd::ops::is_non_decreasing op; + + ASSERT_FALSE(op.verify({&x})); + +} + +TEST_F(BooleanOpsTests, Is_strictly_increasing_1) { + auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 5}); + + sd::ops::is_strictly_increasing op; + + ASSERT_TRUE(op.verify({&x})); + +} + +TEST_F(BooleanOpsTests, Is_strictly_increasing_2) { + auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 3, 3}); + + sd::ops::is_strictly_increasing op; + + ASSERT_FALSE(op.verify({&x})); + +} + +TEST_F(BooleanOpsTests, Is_strictly_increasing_3) { + auto x = NDArrayFactory::create('c', {2 , 2}, {1, 2, 4, 3}); + + sd::ops::is_strictly_increasing op; + + ASSERT_FALSE(op.verify({&x})); +} + +TEST_F(BooleanOpsTests, Is_strictly_increasing_5) { + auto x = NDArrayFactory::create('c', {64, 512}); + x.linspace(1.0); + + sd::ops::is_strictly_increasing op; + + ASSERT_TRUE(op.verify({&x})); +} + +TEST_F(BooleanOpsTests, Is_strictly_increasing_6) { + auto x = NDArrayFactory::create('c', {64, 512}); + x.linspace(1.0); + + x.p(18, 1000323.f); + + sd::ops::is_strictly_increasing op; + + 
ASSERT_FALSE(op.verify({&x})); +} + +TEST_F(BooleanOpsTests, Is_numeric_tensor_1) { + auto x = NDArrayFactory::create('c', {2 , 2}, {1.f, 2.f, 4.f, 3.f}); + + sd::ops::is_numeric_tensor op; + + ASSERT_TRUE(op.verify({&x})); +} + +TEST_F(BooleanOpsTests, test_where_1) { + auto x = NDArrayFactory::create('c', {6}, { 1, -3, 4, 8, -2, 5 }); + auto y = NDArrayFactory::create('c', {6}, { 2, -3, 1, 1, -2, 1 }); + auto e = NDArrayFactory::create('c', {3}, { 4, 8, 5 }); + + sd::ops::choose op; + + auto result = op.evaluate({&x, &y}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + //z->printIndexedBuffer("z"); + + ASSERT_EQ(e, *z); +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BroadcastableOpsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BroadcastableOpsTests.cpp new file mode 100644 index 000000000..85e41d33b --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BroadcastableOpsTests.cpp @@ -0,0 +1,857 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 23.11.17. 
+// + + +#include "testlayers.h" +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class BroadcastableOpsTests : public testing::Test { +public: + +}; + +TEST_F(BroadcastableOpsTests, Test_Add_1) { + + NDArray x('c', {5, 5}, sd::DataType::FLOAT32); + NDArray y('c', {1, 5}, sd::DataType::FLOAT32); + NDArray exp('c', {5, 5}, sd::DataType::FLOAT32); + x.linspace(1); + y.linspace(1); + exp.linspace(1); + + //exp.printIndexedBuffer("E B"); + + exp.applyBroadcast(broadcast::Add, {1}, y, exp); + + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + //exp.printIndexedBuffer("E A"); + //z->printIndexedBuffer("Z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(BroadcastableOpsTests, Test_Multiply_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); + y.linspace(1); + exp.linspace(1); + + exp.applyBroadcast(broadcast::Multiply, {1}, y, exp); + + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + + +TEST_F(BroadcastableOpsTests, Test_SquaredSubtract_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); + y.linspace(1); + exp.linspace(1); + + exp.applyBroadcast(broadcast::SquaredSubtract, {1}, y, exp); + + + sd::ops::squaredsubtract op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_1) { + auto x = NDArrayFactory::create('c', {1, 1}, {1}); + auto y = 
NDArrayFactory::create('c', {1, 3}, {0, 1, 2}); + auto exp = NDArrayFactory::create('c', {1,3}, {1, 0, -1}); + + sd::ops::subtract op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + + +TEST_F(BroadcastableOpsTests, Test_ScalarBroadcast_2) { + auto x = NDArrayFactory::create('c', {1, 1}, {1}); + auto y = NDArrayFactory::create('c', {1, 3}, {0, 1, 2}); + auto exp = NDArrayFactory::create('c', {1,3}, {1, 2, 3}); + + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + + +TEST_F(BroadcastableOpsTests, Test_Maximum_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 2, 3, 2}); + auto row = NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); + auto exp = NDArrayFactory::create('c', {2, 3}, {2, 2, 2, 2, 3, 2}); + + sd::ops::maximum op; + auto result = op.evaluate({&x, &row}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + + +TEST_F(BroadcastableOpsTests, Test_Minimum_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 2, 3, 2}); + auto col = NDArrayFactory::create('c', {2, 1}, {2, 1}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 1, 1, 1, 1}); + + sd::ops::minimum op; + auto result = op.evaluate({&x, &col}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(BroadcastableOpsTests, Test_Shape_1) { + sd::ops::minimum op; + + Nd4jLong shapeX[] = {2, 2, 5, 5, 1, 8192, 1, 99}; + Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); + + auto shapes = 
op.calculateOutputShape(&inputShape, ctx); + + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); + + delete shapes; +} + +TEST_F(BroadcastableOpsTests, Test_Shape_2) { + sd::ops::minimum op; + + const Nd4jLong shapeX[] = {2, 1, 1, 1, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 2, 5, 5, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); + + auto shapes = op.calculateOutputShape(&inputShape, ctx); + + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeY, shapeZ)); + + delete shapes; +} + + +TEST_F(BroadcastableOpsTests, Test_Shape_3) { + sd::ops::minimum op; + + const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 1, 3, 3, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); + + auto shapes = op.calculateOutputShape(&inputShape, ctx); + + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); + + delete shapes; +} + + +TEST_F(BroadcastableOpsTests, Test_Shape_4) { + sd::ops::minimum op; + + const Nd4jLong shapeX[] = {2, 5, 3, 1, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 5, 1, 1, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); + + auto shapes = op.calculateOutputShape(&inputShape, ctx); + + auto shapeZ = shapes->at(0); + ASSERT_TRUE(shape::shapeEquals(shapeX, shapeZ)); + + delete shapes; +} + +// (2,1,3) + (4,3) = (2,4,3) + +TEST_F(BroadcastableOpsTests, Test_Shape_5) { + sd::ops::minimum op; + + const Nd4jLong shapeX[] = {3, 2, 1, 3, 3, 3, 1, 8192, 1, 99}; + const Nd4jLong shapeY[] = {2, 4, 3, 3, 1, 8192, 1, 99}; + const Nd4jLong shapeE[] = {3, 2, 4, 3, 12, 3, 1, 8192, 1, 99}; + ShapeList inputShape({shapeX, shapeY}); + VariableSpace vs; + Context ctx(1, &vs, false); + + auto shapes = op.calculateOutputShape(&inputShape, ctx); + + auto shapeZ = shapes->at(0); + 
ASSERT_TRUE(shape::shapeEquals(shapeE, shapeZ)); + + delete shapes; +} + +TEST_F(BroadcastableOpsTests, Test_Scalar_Add_1) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2, 2}, {3, 4, 5, 6}); + + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(BroadcastableOpsTests, Test_Inplace_Output_1) { + auto x = NDArrayFactory::create('c', {2, 1, 3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto o = NDArrayFactory::create('c', {2, 4, 3}); + auto e = NDArrayFactory::create('c', {2, 4, 3}); + auto buffO1 = reinterpret_cast(o.buffer()); + y.assign(1.0f); + e.assign(1.0f); + + sd::ops::add op; + auto result = op.execute({&x, &y}, {&o}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + + auto buffO2 = reinterpret_cast(o.buffer()); + + ASSERT_TRUE(e.isSameShape(o)); + ASSERT_TRUE(e.equalsTo(o)); + + ASSERT_TRUE(buffO1 == buffO2); +} + +TEST_F(BroadcastableOpsTests, Test_Subtract_1) { + + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + + auto z = x - y; + + ASSERT_TRUE(e.equalsTo(z)); +} + +TEST_F(BroadcastableOpsTests, Test_Subtract_2) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + + sd::ops::subtract op; + auto result = op.evaluate({&x, &y}); + auto z = result.at(0); + + ASSERT_TRUE(e.equalsTo(z)); +} + +TEST_F(BroadcastableOpsTests, Test_Subtract_3) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto z = NDArrayFactory::create('c', {2}, {0.0f, 0.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + + sd::ops::subtract 
op; + auto result = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(e.equalsTo(z)); +} + +TEST_F(BroadcastableOpsTests, Test_Subtract_4) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + + auto z = x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Subtract_5) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {-1., 0.}); + + auto z = y - x; + + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Subtract_6) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(3.f); + + auto z = y - x; + + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Subtract_7) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(-3.f); + + auto z = x - y; + + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Add_2) { + + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + + auto z = x + y; + + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Add_3) { + + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', 
{2}, {0.0f, 1.0f}); + auto e = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + + auto z = y + x; + + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Add_4) { + + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(5.f); + + auto z = x + y; + + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Add_5) { + + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(5.f); + + auto z = y + x; + + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Multiply_2) { + + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {2}, {3.f, 4.f}); + auto e = NDArrayFactory::create('c', {2}, {6.f, 8.f}); + + auto z = y * x; + + ASSERT_TRUE(e.equalsTo(z)); +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Multiply_3) { + + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {2}, {3.f, 4.f}); + auto e = NDArrayFactory::create('c', {2}, {6.f, 8.f}); + + auto z = x * y; + + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Multiply_4) { + + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(4.f); + auto e = NDArrayFactory::create(8.f); + + auto z = y * x; + + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, Test_Multiply_5) { + + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(4.f); + auto e = 
NDArrayFactory::create(8.f); + + auto z = x * y; + + ASSERT_TRUE(e.equalsTo(z)); +} + +TEST_F(BroadcastableOpsTests, Test_Multiply_6) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {1}, {4.f}); + auto e = NDArrayFactory::create('c', {1}, {8.f}); + + auto z = x * y; + + ASSERT_TRUE(e.equalsTo(z)); +} + +TEST_F(BroadcastableOpsTests, Test_Multiply_7) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {1}, {4.f}); + auto e = NDArrayFactory::create('c', {1}, {8.f}); + + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.equalsTo(z)); + +} + +TEST_F(BroadcastableOpsTests, Test_Multiply_8) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create('c', {1, 1}, {4.f}); + auto e = NDArrayFactory::create('c', {1, 1}, {8.f}); + + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, broadcast_add_1) { + + NDArray x('c', {4}, {1,1,1,1}); + NDArray y('c', {1,4}, {1,2,3,4}); + NDArray z('c', {1,4}, sd::DataType::DOUBLE); + NDArray exp('c', {1,4}, {2,3,4,5}, sd::DataType::DOUBLE); + + sd::ops::add op; + auto status = op.execute({&x, &y}, {&z}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.equalsTo(exp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, broadcast_equals_1) { + + NDArray x('c', {1,4}, {1,2,3,4}); + NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4}); + NDArray z('c', {3,4}, sd::DataType::BOOL); + NDArray exp('c', {3,4}, {0,0,0,0, 1,1,1,1, 1,1,1,1}, sd::DataType::BOOL); + + sd::ops::equals op; + auto status = op.execute({&x, &y}, {&z}); + // z.printIndexedBuffer(); + + ASSERT_EQ(ND4J_STATUS_OK, status); + 
ASSERT_TRUE(z.equalsTo(exp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(BroadcastableOpsTests, broadcast_empty_1) { + + NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4}); + NDArray x(sd::DataType::DOUBLE, y.getContext(), false); + NDArray z(sd::DataType::DOUBLE, y.getContext(), false); + NDArray zExp(sd::DataType::DOUBLE, y.getContext(), false); + + sd::ops::multiply op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.isSameShape(zExp)); + ASSERT_TRUE(z.equalsTo(zExp)); +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_2) { + + NDArray y('c', {1,4}, {1,2,3,4}); + NDArray x = NDArrayFactory::create('c', {0, 4}); + NDArray e = NDArrayFactory::create('c', {0, 4});; + + sd::ops::multiply op; + auto status = op.execute({&x, &y}, {&x}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(e.isSameShape(x)); + ASSERT_TRUE(e.equalsTo(x)); +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_3) { + + NDArray x = NDArrayFactory::create('c', {1, 0, 2}); + NDArray y('c', {}, std::vector{0.1}, sd::DataType::FLOAT32); + NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + + sd::ops::maximum op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_4) { + + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 0, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + + sd::ops::maximum op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_5) { + + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 0, 2}); + 
NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_6) { + + NDArray x = NDArrayFactory::create('c', {1, 0, 1}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {2, 2}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2});; + + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + +} + +TEST_F(BroadcastableOpsTests, broadcast_empty_7) { + + NDArray x = NDArrayFactory::create('c', {1, 0, 2, 1}); + NDArray y = NDArrayFactory::create('c', {1, 2, 0}); + NDArray e = NDArrayFactory::create('c', {1, 0, 2, 0});; + + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); +} + + +TEST_F(BroadcastableOpsTests, broadcast_bool_empty_1) { + + NDArray y('c', {3,4}, {0,0,0,0, 1,2,3,4, 1,2,3,4}); + NDArray x(sd::DataType::DOUBLE, y.getContext(), false); + NDArray z(sd::DataType::BOOL, y.getContext(), false); + NDArray zExp(sd::DataType::BOOL, y.getContext(), false); + + sd::ops::greater op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(z.isSameShape(zExp)); + ASSERT_TRUE(z.equalsTo(zExp)); +} + +TEST_F(BroadcastableOpsTests, broadcast_bool_empty_2) { + + NDArray y('c', {1,4}, {1,2,3,4}); + NDArray x = NDArrayFactory::create('c', {0, 4}); + NDArray e = NDArrayFactory::create('c', {0, 4});; + + + sd::ops::greater op; + auto result = op.evaluate({&x, &y}); + + auto z = result.at(0); + + // z->printShapeInfo("z"); + + ASSERT_EQ(Status::OK(), result.status()); + 
ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); +} + +TEST_F(BroadcastableOpsTests, broadcast_bool_1) { + + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); + + x.assign(4.f); + y.assign(2.f); + e.assign(true); + + sd::ops::greater op; + + auto status = op.execute({&x, &y}, {&z}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + // z.printIndexedBuffer("Z"); + + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); +} + +TEST_F(BroadcastableOpsTests, broadcast_bool_2) { + + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); + + x.assign(1.f); + y.assign(2.f); + e.assign(false); + + sd::ops::equals op; + + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + // z.printIndexedBuffer("Z"); + + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); +} + +TEST_F(BroadcastableOpsTests, broadcast_bool_3) { + + auto x = NDArrayFactory::create(0); + auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); + NDArray z('c', {3}, sd::DataType::BOOL); + NDArray e('c', {3}, sd::DataType::BOOL); + + e.assign(true); + + sd::ops::less op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + // z.printIndexedBuffer("Z"); + + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); +} + +TEST_F(BroadcastableOpsTests, broadcast_2) { + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::FLOAT32); + NDArray e('c', {3, 2, 2}, sd::DataType::FLOAT32); + + x = 4.f; + y = 2.f; + e = -2.f; + + sd::ops::reversesubtract op; // z = y - x; + + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + 
ASSERT_EQ(ND4J_STATUS_OK, status); + + // z.printIndexedBuffer("Z"); + + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); +} + +TEST_F(BroadcastableOpsTests, broadcast_3) { + auto x = NDArrayFactory::create(0); + auto y = NDArrayFactory::create('c', {3}, {2, 1, 2}); + NDArray z('c', {3}, sd::DataType::INT32); + auto e = NDArrayFactory::create('c', {3}, {2, 1, 2}); + + sd::ops::add op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + // z.printIndexedBuffer("Z"); + + ASSERT_TRUE(z.isSameShape(e)); + ASSERT_TRUE(z.equalsTo(e)); +} + +TEST_F(BroadcastableOpsTests, test_bert_multiply_1) { + auto x = NDArrayFactory::create('c', {4, 128, 1}); + auto y = NDArrayFactory::create('c', {4, 1, 128}); + auto z = NDArrayFactory::create('c', {4, 128, 128}); + auto e = NDArrayFactory::create('c', {4, 128, 128}); + + x.assign(0.f); + y.assign(1.f); + z.assign(119.f); + e.assign(0.f); +/* + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + sd::ops::multiply op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + + z.printIndexedBuffer(); +*/ + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + + //z.printIndexedBuffer(); + + ASSERT_EQ(e, z); +} + +TEST_F(BroadcastableOpsTests, test_bert_multiply_2) { + auto x = NDArrayFactory::create('c', {4, 128, 1}); + auto y = NDArrayFactory::create('c', {768}); + auto z = NDArrayFactory::create('c', {4, 128, 768}); + auto e = NDArrayFactory::create('c', {4, 128, 768}); + + x.assign(1.f); + y.assign(2.f); + z.assign(119.f); + e.assign(2.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + + ASSERT_EQ(e, z); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BrodcastTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BrodcastTests.cpp new file mode 100644 index 000000000..886de211d --- /dev/null +++ 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/BrodcastTests.cpp @@ -0,0 +1,68 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by agibsonccc on 1/19/17. +// + +#include "testinclude.h" +#include + +class BroadcastMultiDimTest : public testing::Test { +public: + int dimensions[2] = {0,2}; + Nd4jLong inputShapeBuffer[10] = {3,2,3,5,15,5,1,8192,1,99}; + float inputData[30] = {1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,19.0,20.0,21.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0}; + float dataAssertion[30] = {1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,0.0,0.0,21.0,22.0,23.0,0.0,0.0,26.0,27.0,28.0,0.0,0.0}; + float result[30] = {0.0}; + float broadcastData[10] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0}; + Nd4jLong broadcastShapeInfo[8] = {2,2,5,5,1,8192,1,99}; + int opNum = 2; + int dimensionLength = 2; +}; + +#ifndef __CUDABLAS__ + +TEST_F(BroadcastMultiDimTest,MultimDimTest) { + auto tad = new shape::TAD(); + tad->init(inputShapeBuffer,dimensions,dimensionLength); + tad->createTadOnlyShapeInfo(); + tad-> createOffsets(); + functions::broadcast::Broadcast::exec( + 
opNum, + inputData, //x + inputShapeBuffer, //xShapeInfo + broadcastData, //y + broadcastShapeInfo, //yShapeInfo + result, //result + inputShapeBuffer, //resultShapeInfo + dimensions, //dimension + dimensionLength, //dimensionLength + tad->tadOnlyShapeInfo, //tadShapeInfo + tad->tadOffsets, //tadOffset + tad->tadOnlyShapeInfo, //tadShapeInfoZ + tad->tadOffsets, sd::LoopKind::COMMON, 0, tad->numTads); //tadOffsetZ + + for(int i = 0; i < 30; i++) { + ASSERT_EQ(dataAssertion[i],result[i]); + } + + delete tad; +} + +#endif \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CMakeLists.txt b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CMakeLists.txt new file mode 100644 index 000000000..ac8578af0 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CMakeLists.txt @@ -0,0 +1,171 @@ +include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) +if(LINUX) + link_directories(/usr/local/lib) + link_directories(/usr/lib) + link_directories(/lib) +endif() + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if(APPLE) + message("Using apple") + link_directories(/usr/local/lib) + link_directories(/usr/lib) + link_directories(/lib) +endif() +if(WIN32) + get_property(dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES) + foreach(dir ${dirs}) + message(STATUS "dir='${dir}'") + endforeach() +endif() + +if (SD_CUDA) + find_package(CUDA) + message("Tests CUDA include directory: ${CUDA_INCLUDE_DIRS}") + include_directories(${CUDA_INCLUDE_DIRS}) + add_definitions(-D__CUDABLAS__=true) + + if(WIN32) + message("CUDA on Windows: enabling /EHsc") + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /EHsc /FS") + endif() + + string( TOLOWER "${COMPUTE}" COMPUTE_CMP ) + if ("${COMPUTE_CMP}" STREQUAL "all") + CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Common") + elseif("${COMPUTE_CMP}" STREQUAL "auto") + 
CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Auto") + elseif(COMPUTE_CMP MATCHES "^[0-9]+$") + #matches USER COMPUTE old way + set(CUDA_ARCH_FLAGS "-gencode arch=compute_${COMPUTE},code=sm_${COMPUTE} ") + else() + #matches numbers NAME | NUM.NUM | NUM.NUM(NUM.NUM) | NUM.NUM+PTX + #NAME: Fermi Kepler Maxwell Kepler+Tegra Kepler+Tesla Maxwell+Tegra Pascal + #NUM: 2.0 2.1 3.0 3.2 3.5 3.7 5.0 5.2 5.3 6.0 6.2 et cetera + CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "${COMPUTE}") + endif() + # list to spaces + string (REPLACE ";" " " CUDA_ARCH_FLAGS "${CUDA_ARCH_FLAGS}") + + set(CMAKE_CUDA_FLAGS " ${CMAKE_CUDA_FLAGS} -DCUDA_VERSION_MAJOR=${CUDA_VERSION_MAJOR} ${EXPM} -w --cudart=static --expt-extended-lambda -Xfatbin -compress-all ${CUDA_ARCH_FLAGS}") + +endif() + +# -fsanitize=address +# -fsanitize=leak +if (APPLE) + set(CMAKE_CXX_FLAGS " -fPIC -D__APPLE_OS__=true") +elseif(WIN32) + if (SD_CPU) + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -fPIC -mmmx -msse -msse2 -msse3 -mssse3 -msse4.1 -msse4.2 -msse4 -mavx -mavx2 -O3") + endif() + + if (SD_CPU AND LINUX) + set(CMAKE_CXX_FLAGS " -fPIC") + endif() +else() + set(CMAKE_CXX_FLAGS " -fPIC") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") + IF(${SD_ARCH} MATCHES "arm*") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=${SD_ARCH}") + else() + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") + + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*") + set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native") + endif() + endif() + if (SD_CPU AND SD_SANITIZE) + set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address") + else() + # CUDA? 
+ endif() +endif() + + +# tests are always compiled with all ops included +SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSD_ALL_OPS=true -DBUILD_TESTS=true") + +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + # using Clang + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE}") +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel") + # using Intel C++ + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -fp-model fast") +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") + # using Visual Studio C++ + +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + # using GCC + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fmax-errors=2") + + if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND ${CMAKE_SYSTEM_NAME} MATCHES "Linux" AND NOT(MINGW)) + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic") + SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic") + endif() +endif() + +IF(${CMAKE_SYSTEM_NAME} MATCHES "Linux") + include_directories("/usr/include") + include_directories("/usr/local/include") +ENDIF(${CMAKE_SYSTEM_NAME} MATCHES "Linux") + +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9) + message(FATAL_ERROR "You need at least GCC 4.9") +endif() + +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + find_package(OpenMP) +endif() +if (OPENMP_FOUND) + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") +else() + message("OPENMP NOT FOUND") +endif() + +if (SD_CPU) + file(GLOB_RECURSE TEST_SOURCES false ./*.cpp ./*.h) +elseif (SD_CUDA) + file(GLOB_RECURSE TEST_SOURCES false ./*.cpp ./*.cu ./*.h) +endif() + +# Filter out any source files from */CMakeFiles/* paths. these tend to cause problems such a multiple main definitions. 
+set (EXCLUDE_DIR "/CMakeFiles/") +foreach (TMP_PATH ${TEST_SOURCES}) + string (FIND ${TMP_PATH} ${EXCLUDE_DIR} EXCLUDE_DIR_FOUND) + if (NOT ${EXCLUDE_DIR_FOUND} EQUAL -1) + list (REMOVE_ITEM TEST_SOURCES ${TMP_PATH}) + endif () +endforeach(TMP_PATH) + +if (SD_CPU) + if (NOT BLAS_LIBRARIES) + set(BLAS_LIBRARIES "") + endif() + + add_executable(runtests ${TEST_SOURCES}) + target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main) +elseif(SD_CUDA) + + add_executable(runtests ${TEST_SOURCES}) + + if (WIN32) + message("MSVC runtime for tests: ${MSVC_RT_LIB}") + endif() + + # applies to windows only + set_property(TARGET runtests PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$:Debug>") + set_property(TARGET gtest PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$:Debug>") + set_property(TARGET gtest_main PROPERTY MSVC_RUNTIME_LIBRARY "${MSVC_RT_LIB}$<$:Debug>") + + if (HAVE_CUDNN) + message("CUDNN library: ${CUDNN}") + endif() + + target_link_libraries(runtests samediff_obj ${CUDA_LIBRARIES} ${CUDA_CUBLAS_LIBRARIES} ${CUDA_cusolver_LIBRARY} ${CUDNN} ${MKLDNN} gtest gtest_main) +endif() \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CnpyTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CnpyTests.cpp new file mode 100644 index 000000000..407de9e36 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CnpyTests.cpp @@ -0,0 +1,96 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by agibsonccc on 3/30/17. +// + +#include "testinclude.h" +#include +#include + +class FileTest : public testing::Test { + +}; + +class LoadFromStringTest : public testing::Test { + +}; + +class HeaderTest : public testing::Test { + +}; + +TEST_F(HeaderTest, test_dataTypes_1) { + std::string header("0NUMPY6789{'descr': '>f4"); + + + ASSERT_EQ(sd::DataType::FLOAT32, dataTypeFromNpyHeader(const_cast(header.data()))); +} + +TEST_F(HeaderTest, test_dataTypes_2) { + std::string header("0NUMPY6789{'descr': '>f8"); + + + ASSERT_EQ(sd::DataType::DOUBLE, dataTypeFromNpyHeader(const_cast(header.data()))); +} + +TEST_F(HeaderTest, test_dataTypes_3) { + std::string header("0NUMPY6789{'descr': '(header.data()))); +} + +TEST_F(HeaderTest, test_dataTypes_4) { + std::string header("0NUMPY6789{'descr': '>u2"); + + + ASSERT_EQ(sd::DataType::UINT16, dataTypeFromNpyHeader(const_cast(header.data()))); +} + +/* +TEST_F(FileTest,T) { + cnpy::NpyArray npy = cnpy::npyLoad(std::string("/home/agibsonccc/code/libnd4j/test.npy")); + ASSERT_FALSE(npy.fortranOrder); + + ASSERT_EQ(2,npy.shape[0]); + ASSERT_EQ(2,npy.shape[1]); +} + +TEST_F(LoadFromStringTest,PathTest) { + char *loaded = cnpy::loadFile("/home/agibsonccc/code/libnd4j/test.npy"); + cnpy::NpyArray loadedArr = cnpy::loadNpyFromPointer(loaded); + ASSERT_FALSE(loadedArr.fortranOrder); + ASSERT_EQ(2,loadedArr.shape[0]); + ASSERT_EQ(2,loadedArr.shape[1]); + double *data = reinterpret_cast(loadedArr.data); + ASSERT_EQ(1.0,data[0]); + ASSERT_EQ(2.0,data[1]); + 
ASSERT_EQ(3.0,data[2]); + ASSERT_EQ(4.0,data[3]); + Nd4jPointer pointer = reinterpret_cast(&loadedArr); + int *shapeBuffer = shape::shapeBufferOfNpy(loadedArr); + Nd4jPointer pointer1 = dataPointForNumpy(loaded); + delete[] shapeBuffer; + + double *data2 = reinterpret_cast(pointer1); + delete[] loaded; +} + +*/ diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConditionalTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConditionalTests.cpp new file mode 100644 index 000000000..b4dfab6fc --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConditionalTests.cpp @@ -0,0 +1,334 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 16.10.2017. 
+// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class ConditionalTests : public testing::Test { +public: + ConditionalTests(){ + //Environment::getInstance().setVerbose(true); + //Environment::getInstance().setDebug(true); + } + + ~ConditionalTests(){ + //Environment::getInstance().setVerbose(false); + //Environment::getInstance().setDebug(false); + } +}; + + +TEST_F(ConditionalTests, BasicTests_1) { + Graph graph; + + auto x = NDArrayFactory::valueOf({2, 2}, 1.0f); + auto y0 = NDArrayFactory::valueOf({2, 2}, 5.0f); + auto y1 = NDArrayFactory::valueOf({2, 2}, -5.0f); + auto scalar = NDArrayFactory::create_(1.0f); + + auto variableSpace = graph.getVariableSpace(); + + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y0); + variableSpace->putVariable(-3, y1); + variableSpace->putVariable(-4, scalar); + + + auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 1); + scopeCondition->setName("scopeCondition"); + + auto scopeFalse = new Node(OpType_LOGIC, logic::Scope, 2); + scopeFalse->setName("scopeFalse"); + + auto scopeTrue = new Node(OpType_LOGIC, logic::Scope, 3); + scopeTrue->setName("scopeTrue"); + + auto nodeF = new Node(OpType_PAIRWISE, pairwise::Add, 5, {-1, -2}); + nodeF->setScopeInfo(2, "scopeFalse"); + + auto nodeT = new Node(OpType_PAIRWISE, pairwise::Subtract, 6, {-1, -2}); + nodeT->setScopeInfo(3, "scopeTrue"); + + auto nodeC0 = new Node(OpType_REDUCE_SAME, reduce::Sum, 7, {-1}); + nodeC0->setScopeInfo(1, "scopeCondition"); + + sd::ops::eq_scalar op; + auto nodeC1 = new Node(&op, 8, {7, -4}); + nodeC1->setScopeInfo(1, "scopeCondition"); + + graph.addNode(scopeCondition); + graph.addNode(scopeFalse); + graph.addNode(scopeTrue); + graph.addNode(nodeF); + graph.addNode(nodeT); + graph.addNode(nodeC0); + graph.addNode(nodeC1); + + // at this point graph should ounly have Nodes referring to the Scopes: condition scope, true scope and false scope + ASSERT_EQ(3, 
graph.totalNodes()); + + // now we're adding Condition op, that'll take all of those in + auto nodeCondition = new Node(OpType_LOGIC, logic::Conditional, 10, {1, 2, 3}); + graph.addNode(nodeCondition); + + ASSERT_EQ(4, graph.totalNodes()); + + Nd4jStatus status = GraphExecutioner::execute(&graph); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(variableSpace->hasVariable(10, 0)); + auto conditionalResult = variableSpace->getVariable(10, 0)->getNDArray(); + ASSERT_NE(nullptr, conditionalResult); + + ASSERT_NEAR(6.0, conditionalResult->meanNumber().e(0), 1e-5); +} +#ifdef GRAPH_FILES_OK +/** + * Condition is False + */ +TEST_F(ConditionalTests, Flat_Test_1) { + sd::ops::identity op0; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simpleif_0_1.fb"); + auto varSpace = graph->getVariableSpace(); + //varSpace->getVariable(1)->getNDArray()->assign(2.0); + //varSpace->getVariable(2)->getNDArray()->assign(0.0); + + //graph->printOut(); + + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(varSpace->hasVariable(15)); + + auto z = varSpace->getVariable(15)->getNDArray(); + + ASSERT_NE(nullptr, z); + + auto exp = NDArrayFactory::create('c', {2, 2}, {-2, -2, -2, -2}); + + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + +/** + * Condition is True + */ +TEST_F(ConditionalTests, Flat_Test_2) { + Environment::getInstance().setDebug(true); + Environment::getInstance().setVerbose(true); + sd::ops::identity op0; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simpleif_0.fb"); + auto varSpace = graph->getVariableSpace(); + varSpace->getVariable(1)->getNDArray()->assign(-1.0); + + graph->printOut(); + + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(varSpace->hasVariable(15)); + + auto z = varSpace->getVariable(15)->getNDArray(); + + ASSERT_NE(nullptr, z); + + auto exp = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + + 
ASSERT_TRUE(exp.equalsTo(z)); + delete graph; +} + + +/** + * Condition is false here, so there loop will be skipped + */ +TEST_F(ConditionalTests, Flat_Test_3) { + sd::ops::identity op0; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_3.fb"); + auto varSpace = graph->getVariableSpace(); + varSpace->getVariable(1)->getNDArray()->assign(1.0); + + //graph->printOut(); + + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(varSpace->hasVariable(17)); + + auto z = varSpace->getVariable(17)->getNDArray(); + + ASSERT_NE(nullptr, z); + + auto exp = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + +/** + * just one cycle in body + */ +TEST_F(ConditionalTests, Flat_Test_4) { + sd::ops::identity op0; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_4.fb"); + auto varSpace = graph->getVariableSpace(); + varSpace->getVariable(2)->getNDArray()->assign(4.0); + + //graph->printOut(); + + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(varSpace->hasVariable(17)); + + auto z = varSpace->getVariable(17)->getNDArray(); + + ASSERT_NE(nullptr, z); + + // 0.0 + 2.0 = 2.0 in each element + auto exp = NDArrayFactory::create('c', {2, 2}, {2, 2, 2, 2}); + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + + +/** + * just two cycles in body + */ +TEST_F(ConditionalTests, Flat_Test_5) { + sd::ops::identity op0; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_0_4.fb"); + auto varSpace = graph->getVariableSpace(); + varSpace->getVariable(2)->getNDArray()->assign(9.0); + + //graph->printOut(); + + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(varSpace->hasVariable(17)); + + auto z = varSpace->getVariable(17)->getNDArray(); + + ASSERT_NE(nullptr, z); + + // 0.0 + 2.0 + 2.0 = 4.0 in each 
element + auto exp = NDArrayFactory::create('c', {2, 2}, {4, 4, 4, 4}); + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + +/** + * While loop with multiple variables + */ +TEST_F(ConditionalTests, Flat_Test_6) { + sd::ops::identity op0; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_1.fb"); + auto varSpace = graph->getVariableSpace(); + varSpace->getVariable(1)->getNDArray()->assign(-4.0f); + varSpace->getVariable(2)->getNDArray()->assign(1.0f); + + //graph->printOut(); + + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(varSpace->hasVariable(25)); + + auto z = varSpace->getVariable(25)->getNDArray(); + + ASSERT_NE(nullptr, z); + + //z->printIndexedBuffer(); + + auto exp = NDArrayFactory::create('c', {2, 2}, {-1, -1, -1, -1}); + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + +TEST_F(ConditionalTests, Flat_Test_7) { + sd::ops::identity op0; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_1.fb"); + auto varSpace = graph->getVariableSpace(); + varSpace->getVariable(1)->getNDArray()->assign(-9.0f); + varSpace->getVariable(2)->getNDArray()->assign(1.0f); + + //graph->printOut(); + + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(varSpace->hasVariable(25)); + + auto z = varSpace->getVariable(25)->getNDArray(); + + ASSERT_NE(nullptr, z); + + //z->printIndexedBuffer(); + + auto exp = NDArrayFactory::create('c', {2, 2}, {-3, -3, -3, -3}); + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + +/** + * This test checks nested while execution + */ +TEST_F(ConditionalTests, Flat_Test_8) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simplewhile_nested.fb"); + auto varSpace = graph->getVariableSpace(); + //graph->printOut(); + + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(varSpace->hasVariable(52)); + + auto z = 
varSpace->getVariable(52)->getNDArray(); + + ASSERT_NE(nullptr, z); + + //val exp = Nd4j.create(2, 2).assign(15.0); + auto exp = NDArrayFactory::create('c', {2, 2}, {15, 15, 15, 15}); + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} +#endif diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp new file mode 100644 index 000000000..97083c52e --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConstantShapeHelperTests.cpp @@ -0,0 +1,351 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class ConstantShapeHelperTests : public testing::Test { +public: + +}; + +class ConstantHelperTests : public testing::Test { +public: + +}; + +class ConstantTadHelperTests : public testing::Test { +public: + +}; + +TEST_F(ConstantShapeHelperTests, test_cachedAmount_1) { + auto ttlBefore = ConstantShapeHelper::getInstance().totalCachedEntries(); + + auto arrayA = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); + + auto ttlMiddle = ConstantShapeHelper::getInstance().totalCachedEntries(); + + auto arrayB = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); + + auto ttlAfter = ConstantShapeHelper::getInstance().totalCachedEntries(); + + ASSERT_TRUE(ttlBefore <= ttlMiddle); + ASSERT_EQ(ttlMiddle, ttlAfter); +} + +TEST_F(ConstantTadHelperTests, test_cachedAmount_1) { + auto arrayA = NDArrayFactory::create('c', {7, 11, 17, 23, 31, 43}); + auto ttlBefore = ConstantTadHelper::getInstance().totalCachedEntries(); + + auto packAA = ConstantTadHelper::getInstance().tadForDimensions(arrayA.shapeInfo(), {3, 4}); + + auto ttlMiddle = ConstantTadHelper::getInstance().totalCachedEntries(); + + auto packAB = ConstantTadHelper::getInstance().tadForDimensions(arrayA.shapeInfo(), {3, 4}); + + auto ttlAfter = ConstantTadHelper::getInstance().totalCachedEntries(); + + ASSERT_TRUE(ttlBefore <= ttlMiddle); + ASSERT_EQ(ttlMiddle, ttlAfter); +} + +TEST_F(ConstantShapeHelperTests, basic_test_1) { + auto ptr = ShapeBuilders::createShapeInfo(sd::DataType::BFLOAT16, 'f', {5, 10, 15}); + ShapeDescriptor descriptor(ptr); + ShapeDescriptor descriptor2(ptr); + + ASSERT_EQ(descriptor, descriptor2); + + ASSERT_EQ(1, descriptor.ews()); + ASSERT_EQ(3, descriptor.rank()); + 
ASSERT_EQ('f', descriptor.order()); + ASSERT_EQ(sd::DataType::BFLOAT16, descriptor.dataType()); + ASSERT_FALSE(descriptor.isEmpty()); + + ASSERT_FALSE(ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo(descriptor)); + + auto buffer = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor); + + ASSERT_TRUE(ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo(descriptor)); + + auto buffer2 = ConstantShapeHelper::getInstance().bufferForShapeInfo(descriptor2); + + + ASSERT_TRUE(buffer.primary() != nullptr); + ASSERT_TRUE(buffer.primary() == buffer2.primary()); + ASSERT_TRUE(buffer.special() == buffer2.special()); + + delete []ptr; +} + +TEST_F(ConstantShapeHelperTests, stress_test_1) { + + for (auto x = 0; x < 1000; x++) { + auto ptr = ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', {5, x + 10, x + 1}); + ShapeDescriptor descriptor(ptr); + ConstantShapeHelper::getInstance().createShapeInfo(descriptor); + delete [] ptr; + } + ShapeDescriptor aShape(sd::DataType::FLOAT32, 'c', {(Nd4jLong)5, (Nd4jLong)382, (Nd4jLong)373}); +// nd4j_printf("%d\n", ConstantShapeHelper::getInstance().cachedEntriesForDevice(0)); + + auto timeStart = std::chrono::system_clock::now(); + ASSERT_TRUE(ConstantShapeHelper::getInstance().checkBufferExistenceForShapeInfo(aShape)); + auto timeEnd = std::chrono::system_clock::now(); + + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + nd4j_printf("Total time (us) %lld\n", outerTime); +} + +TEST_F(ConstantShapeHelperTests, basic_test_3) { + auto array = NDArrayFactory::create_('c', {128}); + + ASSERT_TRUE(array->shapeInfo() != nullptr); + +#ifdef __CUDABLAS__ + ASSERT_TRUE(array->specialShapeInfo() != nullptr); +#endif + + delete array; +} + + +TEST_F(ConstantShapeHelperTests, basic_test_4) { + auto array = NDArrayFactory::create_('c', {128, 256}); + + auto dup = new NDArray(array->dup('f')); + + ASSERT_TRUE(dup->shapeInfo() != nullptr); + +#ifdef __CUDABLAS__ + 
ASSERT_TRUE(dup->specialShapeInfo() != nullptr); + PointersManager manager(sd::LaunchContext ::defaultContext(), "test"); + // manager.printDevContentOnDev(dup->special(), shape::shapeInfoLength(2), 0); +#endif + + delete array; + delete dup; +} + + +TEST_F(ConstantShapeHelperTests, basic_test_5) { + + auto arrayA = NDArrayFactory::create(1); + auto arrayB = NDArrayFactory::create_('c', {128, 256}); + + //arrayA.printShapeInfo("A"); + //arrayB->printShapeInfo("B"); + ASSERT_EQ(0, arrayA.rankOf()); + ASSERT_EQ(2, arrayB->rankOf()); + ASSERT_NE(arrayA.dataType(), arrayB->dataType()); + + delete arrayB; +} + +TEST_F(ConstantShapeHelperTests, basic_test_6) { + ShapeDescriptor descriptorA(sd::DataType::INT32, 'c', {}); + ShapeDescriptor descriptorB(sd::DataType::FLOAT32, 'c', {10, 10}); + + // ASSERT_FALSE(descriptorA < descriptorB); + // ASSERT_TRUE(descriptorB < descriptorA); + + ASSERT_TRUE(descriptorA < descriptorB); + ASSERT_FALSE(descriptorB < descriptorA); +} + +TEST_F(ConstantShapeHelperTests, basic_test_7) { + auto array = NDArrayFactory::create_('c', {32, 256}); + + IndicesList indices({NDIndex::all(), NDIndex::interval(0,1)}); + auto strided = array->subarray(indices); + strided.assign(1.0f); + + //strided->printIndexedBuffer("column"); + + delete array; +} + +TEST_F(ConstantHelperTests, basic_test_1) { + + ConstantDescriptor descriptor({1, 2, 3}); + + ConstantDataBuffer* fBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::FLOAT32); + auto fPtr = fBuffer->primaryAsT(); + + ASSERT_NEAR(1.f, fPtr[0], 1e-5); + ASSERT_NEAR(2.f, fPtr[1], 1e-5); + ASSERT_NEAR(3.f, fPtr[2], 1e-5); + + auto iBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::INT32); + auto iPtr = iBuffer->primaryAsT(); + + ASSERT_EQ(1, iPtr[0]); + ASSERT_EQ(2, iPtr[1]); + ASSERT_EQ(3, iPtr[2]); +} + +TEST_F(ConstantHelperTests, basic_test_2) { + + double array[] = {1., 2., 3.}; + ConstantDescriptor descriptor(array, 3); + + ConstantDataBuffer* 
fBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::FLOAT32); + auto fPtr = fBuffer->primaryAsT(); + + ASSERT_NEAR(1.f, fPtr[0], 1e-5); + ASSERT_NEAR(2.f, fPtr[1], 1e-5); + ASSERT_NEAR(3.f, fPtr[2], 1e-5); + + auto iBuffer = ConstantHelper::getInstance().constantBuffer(descriptor, sd::DataType::INT32); + auto iPtr = iBuffer->primaryAsT(); + + ASSERT_EQ(1, iPtr[0]); + ASSERT_EQ(2, iPtr[1]); + ASSERT_EQ(3, iPtr[2]); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConstantShapeHelperTests, ShapeDescriptor_1) { + + Nd4jLong shapeInfo1[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99}; + Nd4jLong shapeInfo2[] = {4, 2, 5, 5, 2, 50, 10, 2, 1, 8192, 1, 99}; + + ShapeDescriptor descr1(shapeInfo1); + ShapeDescriptor descr2(shapeInfo2); + + ASSERT_FALSE(descr1 == descr2); +} + +TEST_F(ConstantShapeHelperTests, ShapeDescriptor_validation) { + + //for c order + std::vector shape{ 2,3,4,5 }; + std::vector incorrectStride1{ 20,20,5,1 }; + std::vector incorrectStride2{ 60,20,5,5 }; + std::vector correctStride1{ 60,20,5,1 }; + std::vector correctStride2{ 300,100,25,5 }; + std::vector correctStride3{ 800, 200, 40, 5 }; + + auto shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, incorrectStride1, 1); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_STRIDES); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, correctStride1, 1); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, incorrectStride2, 1); + ASSERT_TRUE(shapeDesc.validate() == (SHAPE_DESC_INCORRECT_STRIDES | SHAPE_DESC_INCORRECT_EWS)); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, correctStride2, 1); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_EWS); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, correctStride2, 5); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, 
correctStride3, 1); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_EWS); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'c', shape, correctStride3, 0); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK); + + //order f + std::reverse(std::begin(shape), std::end(shape)); + std::reverse(std::begin(incorrectStride1), std::end(incorrectStride1)); + std::reverse(std::begin(incorrectStride2), std::end(incorrectStride2)); + std::reverse(std::begin(correctStride1), std::end(correctStride1)); + std::reverse(std::begin(correctStride2), std::end(correctStride2)); + std::reverse(std::begin(correctStride3), std::end(correctStride3)); + + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, incorrectStride1, 1); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_STRIDES); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, correctStride1, 1); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, incorrectStride2, 1); + ASSERT_TRUE(shapeDesc.validate() == (SHAPE_DESC_INCORRECT_STRIDES | SHAPE_DESC_INCORRECT_EWS)); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, correctStride2, 1); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_EWS); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, correctStride2, 5); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, correctStride3, 1); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_INCORRECT_EWS); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape, correctStride3, 0); + ASSERT_TRUE(shapeDesc.validate() == SHAPE_DESC_OK); + + std::vector shape1; + shape1.resize(MAX_RANK+1); + shapeDesc = ShapeDescriptor(DataType::FLOAT32, 'f', shape1, correctStride3, 0); + ASSERT_TRUE( (shapeDesc.validate() & SHAPE_DESC_INCORRECT_RANK) == SHAPE_DESC_INCORRECT_RANK); + +} + +TEST_F(ConstantShapeHelperTests, ShapeDescriptor_paddedBuffer) { + + constexpr int 
n = 2; + constexpr int c = 3; + constexpr int h = 4; + constexpr int w = 5; + constexpr int n_pad = 2; + constexpr int c_pad = 3; + constexpr int h_pad = 4; + constexpr int w_pad = 5; + char orders[] = { 'c', 'f' }; + + for (auto& order : orders) { + auto shapeDesc1 = ShapeDescriptor::paddedBufferDescriptor(DataType::FLOAT32, order, { n, c, h, w }, { n_pad, c_pad, h_pad, w_pad }); + auto shapeDesc2 = ShapeDescriptor(DataType::FLOAT32, order, { n + n_pad, c + c_pad, h + h_pad, w + w_pad }); + auto shapeDesc3 = ShapeDescriptor::paddedBufferDescriptor(DataType::FLOAT32, order, { n, c, h, w }, { n_pad, c_pad }); + auto shapeDesc4 = ShapeDescriptor(DataType::FLOAT32, order, { n + n_pad, c + c_pad, h, w }); + auto shapeDesc5 = ShapeDescriptor::paddedBufferDescriptor(DataType::FLOAT32, order, { n, c, h, w }, { 0, 0, h_pad, w_pad }); + auto shapeDesc6 = ShapeDescriptor(DataType::FLOAT32, order, { n, c , h + h_pad, w + w_pad }); + + ASSERT_TRUE(shapeDesc1.validate() == SHAPE_DESC_OK); + ASSERT_TRUE(shapeDesc2.validate() == SHAPE_DESC_OK); + ASSERT_TRUE(shapeDesc3.validate() == SHAPE_DESC_OK); + ASSERT_TRUE(shapeDesc4.validate() == SHAPE_DESC_OK); + ASSERT_TRUE(shapeDesc5.validate() == SHAPE_DESC_OK); + ASSERT_TRUE(shapeDesc6.validate() == SHAPE_DESC_OK); + + ASSERT_TRUE(shapeDesc1.allocLength() == shapeDesc2.allocLength()); + ASSERT_TRUE(shapeDesc3.allocLength() == shapeDesc4.allocLength()); + ASSERT_TRUE(shapeDesc5.allocLength() == shapeDesc6.allocLength()); + + const auto& v1 = shapeDesc1.strides(); + const auto& v2 = shapeDesc2.strides(); + const auto& v3 = shapeDesc3.strides(); + const auto& v4 = shapeDesc4.strides(); + const auto& v5 = shapeDesc5.strides(); + const auto& v6 = shapeDesc6.strides(); + + for (int i = 0; i < v1.size(); i++) { + ASSERT_TRUE(v1[i] == v2[i]); + } + for (int i = 0; i < v3.size(); i++) { + ASSERT_TRUE(v3[i] == v4[i]); + } + for (int i = 0; i < v5.size(); i++) { + ASSERT_TRUE(v5[i] == v6[i]); + } + } + +} diff --git 
a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ContextTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ContextTests.cpp new file mode 100644 index 000000000..afbe62950 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ContextTests.cpp @@ -0,0 +1,358 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 30.10.2017. 
+// + +#include "testlayers.h" +#include + +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class ContextTests : public testing::Test { +public: + +}; + + +TEST_F(ContextTests, Basic_Test_1) { + VariableSpace variableSpace; + + auto _20 = NDArrayFactory::create_('c', {2, 2}); + auto _21 = NDArrayFactory::create_('c', {2, 2}); + + _20->assign(1.0f); + _21->assign(2.0f); + + variableSpace.putVariable(2, 0, _20); + variableSpace.putVariable(2, 1, _21); + + Context block(1, &variableSpace); + + block.pickInput(2, 0); + block.pickInput(2, 1); + + ASSERT_EQ(2, block.inputs()->size()); + ASSERT_EQ(2, block.width()); + + ASSERT_TRUE(variableSpace.hasVariable(2, 0)); + ASSERT_TRUE(variableSpace.hasVariable(2, 1)); + + ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e(0), 1e-5); + ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e(0), 1e-5); +} + + +TEST_F(ContextTests, Basic_Test_2) { + VariableSpace variableSpace; + + auto _20 = NDArrayFactory::create_('c', {2, 2}); + auto _21 = NDArrayFactory::create_('c', {2, 2}); + + _20->assign(1.0f); + _21->assign(2.0f); + + variableSpace.putVariable(-1, _20); + variableSpace.putVariable(-2, _21); + + Context block(1, &variableSpace); + + block.pickInput(-1); + block.pickInput(-2); + + ASSERT_EQ(2, block.inputs()->size()); + ASSERT_EQ(2, block.width()); + + ASSERT_TRUE(variableSpace.hasVariable(-1)); + ASSERT_TRUE(variableSpace.hasVariable(-2)); + + ASSERT_NEAR(1.0f, block.variable(0)->getNDArray()->meanNumber().e(0), 1e-5); + ASSERT_NEAR(2.0f, block.variable(1)->getNDArray()->meanNumber().e(0), 1e-5); +} + + +TEST_F(ContextTests, Basic_Test_3) { + VariableSpace variableSpace; + + Context ctx(1, &variableSpace); + + auto _20 = NDArrayFactory::create_('c', {2, 2}); + + ctx.pushNDArrayToVariableSpace(1, 1, _20); + + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); +} + + +TEST_F(ContextTests, Basic_Test_4) { + VariableSpace variableSpace; + + Context ctx(1, &variableSpace); + + 
auto _20 = NDArrayFactory::create_('c', {2, 2}); + _20->linspace(1); + + auto _21 = NDArrayFactory::create_('c', {2, 2}); + _21->linspace(10); + + ctx.pushNDArrayToVariableSpace(1, 1, _20); + + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + + ctx.pushNDArrayToVariableSpace(1, 1, _21); + + auto vA = ctx.variable(1, 1); + + ASSERT_TRUE(vA->getNDArray()->equalsTo(_21)); +} + +TEST_F(ContextTests, Basic_Test_5) { + VariableSpace variableSpace; + + Context ctx(1, &variableSpace); + + auto _20 = NDArrayFactory::create_('c', {2, 2}); + _20->linspace(1); + + auto exp = new NDArray(_20->dup()); + + ctx.pushNDArrayToVariableSpace(1, 1, _20); + + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + + ctx.pushNDArrayToVariableSpace(1, 1, _20); + + auto vA = ctx.variable(1, 1); + + ASSERT_TRUE(vA->getNDArray() == _20); + + ASSERT_TRUE(vA->getNDArray()->equalsTo(exp)); + + delete exp; +} + + +TEST_F(ContextTests, Basic_Test_6) { + VariableSpace variableSpace; + + Context ctx(1, &variableSpace); + + auto v0 = ctx.ensureVariable(); + auto v1 = ctx.ensureVariable(1); + + ASSERT_TRUE(variableSpace.hasVariable(1, 0)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + + auto var0 = variableSpace.getVariable(1, 0); + auto var1 = variableSpace.getVariable(1, 1); + + ASSERT_TRUE(v0 == var0); + ASSERT_TRUE(v1 == var1); +} + + +TEST_F(ContextTests, Basic_Test_7) { + VariableSpace variableSpace; + + Context ctx(1, &variableSpace); + + auto v0 = ctx.ensureVariable(); + auto v1 = ctx.ensureVariable(1); + + ASSERT_TRUE(variableSpace.hasVariable(1, 0)); + ASSERT_TRUE(variableSpace.hasVariable(1, 1)); + + auto var0 = variableSpace.getVariable(1, 0); + auto var1 = variableSpace.getVariable(1, 1); + + ASSERT_TRUE(v0 == var0); + ASSERT_TRUE(v1 == var1); + + + auto _10 = NDArrayFactory::create_('c', {2, 2}); + _10->linspace(1); + + auto _11 = NDArrayFactory::create_('c', {2, 2}); + _11->linspace(10); + + ctx.pushNDArrayToVariableSpace(1, 0, _10); + ctx.pushNDArrayToVariableSpace(1, 1, _11); + + auto z0 
= variableSpace.getVariable(1, 0); + auto z1 = variableSpace.getVariable(1, 1); + + ASSERT_TRUE(v0 == z0); + ASSERT_TRUE(v1 == z1); +} + +TEST_F(ContextTests, Basic_Test_8) { + VariableSpace variableSpace; + + Context ctx(1, &variableSpace); + + auto _10 = NDArrayFactory::create_('c', {2, 2}); + _10->linspace(1); + + auto _11 = NDArrayFactory::create_('c', {2, 2}); + _11->linspace(10); + + ctx.pushNDArrayToVariableSpace(1, 0, _10); + ctx.pushNDArrayToVariableSpace(1, 1, _11); + + auto z0 = variableSpace.getVariable(1, 0); + auto z1 = variableSpace.getVariable(1, 1); + + auto v0 = ctx.ensureVariable(); + auto v1 = ctx.ensureVariable(1); + + ASSERT_TRUE(v0 == z0); + ASSERT_TRUE(v1 == z1); +} + + +TEST_F(ContextTests, Basic_Test_9) { + VariableSpace variableSpace; + + auto in = NDArrayFactory::create('c', {5, 5}); + + Context ctx(1, &variableSpace, true); + ctx.pushNDArrayToVariableSpace(1, 1, &in, false); +} + +TEST_F(ContextTests, Basic_Test_10) { + VariableSpace variableSpace; + + Context ctx(119, &variableSpace); +} + + +TEST_F(ContextTests, Prototype_Test_1) { + ContextPrototype prototype(nullptr, 119, true); + prototype.pickInput(12, 3); + prototype.pickInput(12, 4); + + prototype.getTArguments()->push_back(2.0); + prototype.getTArguments()->push_back(-2.0); + + prototype.getIArguments()->push_back(17); + prototype.getIArguments()->push_back(119); + + Context ctx(&prototype, nullptr); + + ASSERT_EQ(ctx.nodeId(), prototype.nodeId()); + ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); + + ASSERT_EQ(2, ctx.inputs()->size()); + ASSERT_EQ(2, ctx.getTArguments()->size()); + ASSERT_EQ(2, ctx.getIArguments()->size()); + + ASSERT_EQ(2.0, ctx.getTArguments()->at(0)); + ASSERT_EQ(-2.0, ctx.getTArguments()->at(1)); + + ASSERT_EQ(17, ctx.getIArguments()->at(0)); + ASSERT_EQ(119, ctx.getIArguments()->at(1)); +} + + +TEST_F(ContextTests, Prototype_Test_2) { + ContextPrototype prototype(nullptr, 119, false); + prototype.setOpNum(179); + + Context ctx(&prototype, nullptr); + + 
ASSERT_EQ(ctx.isInplace(), prototype.isInplace()); + ASSERT_EQ(ctx.opNum(), prototype.opNum()); + + ASSERT_EQ(0, ctx.inputs()->size()); + ASSERT_EQ(0, ctx.getTArguments()->size()); + ASSERT_EQ(0, ctx.getIArguments()->size()); +} + +TEST_F(ContextTests, test_short_context_1) { + auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, {-1.f, -2.f, -3.f, -4.f, -5.f, -6.f}); + Context ctx(1); + + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); + + ASSERT_EQ(2, ctx.width()); + + auto input0 = ctx.array(0); + ASSERT_TRUE(input0 != nullptr); + + auto input1 = ctx.array(1); + ASSERT_TRUE(input1 != nullptr); + + ASSERT_TRUE(input0->buffer() == array0.buffer()); + ASSERT_TRUE(input0->shapeInfo() == array0.shapeInfo()); + + ASSERT_TRUE(input0->specialBuffer() == array0.specialBuffer()); + ASSERT_TRUE(input0->specialShapeInfo() == array0.specialShapeInfo()); + + ASSERT_TRUE(input1->buffer() == array1.buffer()); + ASSERT_TRUE(input1->shapeInfo() == array1.shapeInfo()); + + ASSERT_TRUE(input1->specialBuffer() == array1.specialBuffer()); + ASSERT_TRUE(input1->specialShapeInfo() == array1.specialShapeInfo()); +} + +TEST_F(ContextTests, test_short_context_2) { + auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {3, 2}); + + auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); + + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), 
z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + ASSERT_EQ(2, ctx.width()); + + sd::ops::add op; + op.execute(&ctx); + + ASSERT_EQ(exp, z); +} + +TEST_F(ContextTests, test_short_context_3) { + auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + + auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); + + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); + + ASSERT_EQ(2, ctx.width()); + + sd::ops::add op; + op.execute(&ctx); + + ASSERT_EQ(1, ctx.fastpath_out().size()); + + auto z = ctx.fastpath_out()[0]; + + ASSERT_EQ(exp, *z); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConvolutionTests1.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConvolutionTests1.cpp new file mode 100644 index 000000000..3cc8be96e --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -0,0 +1,2921 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_CONVOLUTIONTESTS1_H +#define LIBND4J_CONVOLUTIONTESTS1_H + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef HAVE_MKLDNN +#include +#endif + +using namespace sd; +using namespace sd::graph; + +class ConvolutionTests1 : public testing::Test { +public: + +}; + +template +class TypedConvolutionTests1 : public testing::Test { +public: + +}; + +typedef ::testing::Types TestingTypes; +TYPED_TEST_CASE(TypedConvolutionTests1, TestingTypes); + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_1) { + + int bS=1, iH=5,iW=4, iC=2,oC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + TypeParam _expB[]{664.0, 700.0, 736.0, 344.0, 808.0, 844.0, 880.0, 408.0, 952.0, 988.0, 1024.0, 472.0, 1096.0, 1132.0, 1168.0, 536.0, 466.0, 480.0, 494.0, 220.0, 1528.0, 1628.0, 1728.0, 856.0, 1928.0, 2028.0, 2128.0, 1048.0, 2328.0, 2428.0, 2528.0, 1240.0, 2728.0, 2828.0, 2928.0, 1432.0, 1346.0, 1392.0, 1438.0, 700.0, 2392.0, 2556.0, 2720.0, 1368.0, 3048.0, 3212.0, 3376.0, 1688.0, 3704.0, 3868.0, 4032.0, 2008.0, 4360.0, 4524.0, 4688.0, 2328.0, 2226.0, 2304.0, 2382.0, 1180.0}; + Nd4jLong _expS[]{4, 1, 3, 5, 4, 60, 20, 4, 1, typeid(TypeParam) == typeid(float) ? 
8192 : 16384, 1, 99}; + auto input = NDArrayFactory::create_('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create_('c', {oC, iC, kH, kW}); + for (int e = 0; e < input->lengthOf(); e++) + input->p(e, e + 1); + + for (int e = 0; e < weights->lengthOf(); e++) + weights->p(e, e + 1); + weights->permutei({2,3,1,0}); + + // weights->printShapeInfo("weights"); + + ArrayOptions::setDataType(_expS, input->dataType()); + auto exp = new NDArray(_expB, _expS); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, weights); + + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1, -2}); + // 5,5 kernel + block->getIArguments()->push_back(kH); + block->getIArguments()->push_back(kW); + + // 1,1 stride + block->getIArguments()->push_back(sH); + block->getIArguments()->push_back(sW); + + // 0,0 padding + block->getIArguments()->push_back(pH); + block->getIArguments()->push_back(pW); + + // 1,1 dilation + block->getIArguments()->push_back(dH); + block->getIArguments()->push_back(dW); + + // same mode + block->getIArguments()->push_back(1); + + // is NHWC + block->getIArguments()->push_back(0); + + sd::ops::conv2d op; + + Nd4jStatus status = op.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto res = variableSpace->getVariable(1)->getNDArray(); + + + // checking output shape + ASSERT_EQ(1, res->sizeAt(0)); + ASSERT_EQ(3, res->sizeAt(1)); + ASSERT_EQ(5, res->sizeAt(2)); + ASSERT_EQ(4, res->sizeAt(3)); + + // basically the same as above + ASSERT_TRUE(res->isSameShape(exp)); + // just for visual validation + // exp->printIndexedBuffer("Expected"); + // res->printIndexedBuffer("Actual "); + // res->printShapeInfo("Result shape"); + // final check + ASSERT_TRUE(res->equalsTo(exp)); + + delete block; + delete variableSpace; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_2) { + 
auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f}); + + weights.assign(2.0); + input.linspace(1); + + sd::ops::conv2d op; + auto result = op.evaluate({&input, &weights}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_3) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1.f, 2.f, 3.f}); + + + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{ 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f, + 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + 
+ +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_4) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); + + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{ 170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f,170.4f,175.20001f,180.f}); + + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_5) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); + + auto expOutput = NDArrayFactory::create('c', {bS, oC, oH, oW}, {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); + + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2,3,1,0}); + + sd::ops::conv2d op; + auto results = op.evaluate({&input, &weights, &bias}, 
{}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_6) { + auto input = NDArrayFactory::create('c', {54, 1, 12, 12}); + auto weights = NDArrayFactory::create('c', {1, 2, 12, 2}); + + sd::ops::conv2d op; + auto result = op.evaluate({&input, &weights}, {}, {-1,-1, 1,1, 0,0, 1,1, 1,1}); + ASSERT_EQ(Status::OK(), result.status()); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_7) { + + int bS=1, iH=256,iW=256, iC=1,oC=1, kH=4,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + // int oH=256,oW=256; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + + input = 5.; + weights = 3.; + + sd::ops::conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_8) { + + int bS=1, iH=6,iW=8, iC=2,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=6,oW=8; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, + 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, + 0.061736, 
0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, + 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, + 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, 0.608414, 0.956500, 0.390098}); + + NDArray weights('c', {kH, kW, iC, oC}, {0.07581716775894165, 0.8706002235412598, 0.29345420002937317, 0.5281786322593689, 0.10540834069252014, 0.3663792014122009, 0.17209206521511078, 0.6257694959640503}); + NDArray bias('c', {1, oC}, {0.7414038777351379, 0.8980839848518372}); + + NDArray expOutput('c', {bS, oC, oH, oW}, {1.112878, 1.106691, 0.914598, 1.127438, 0.988108, 1.070572, 1.040759, 0.962728, 0.927537, 1.109045, 0.893301, 1.101278, 1.080314, + 1.112327, 1.030041, 0.955914, 0.779137, 1.110499, 0.944709, 1.195986, 0.997814, 1.083822, 1.090898, 0.889572, 0.964781, 1.071012, 1.111928, 1.291319, 1.085454, 0.977661, + 1.149068, 1.077099, 1.068283, 1.064290, 1.177125, 1.212480, 0.932593, 0.939493, 1.118576, 1.056927, 0.780314, 0.845707, 0.996308, 0.963152, 0.906792, 0.937590, 1.048791, + 0.860346, 2.264212, 2.071576, 1.916629, 2.030785, 2.169075, 2.039786, 1.935480, 2.177816, 1.524273, 1.933327, 1.630923, 2.406983, 1.770406, 2.413284, 1.790349, 1.476586, + 1.179925, 1.909109, 2.009143, 2.299778, 1.957207, 1.779718, 2.480604, 1.529086, 1.748063, 1.952856, 2.029487, 2.699131, 1.879842, 1.471205, 2.150177, 2.039078, 1.933456, + 1.764169, 2.584944, 2.521004, 1.744296, 1.707578, 2.237938, 2.325231, 0.984485, 1.766936, 1.590640, 1.347524, 1.404648, 1.422042, 1.709862, 1.155412}); + + sd::ops::conv2d op; + auto 
results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results.at(0); + + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_9) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, iC, kH, kW}, {-3., -1.8, -0.6, 0.6, 1.8, 3., -2.7, -1.5, -0.3, 0.9, 2.1, 3.3, -2.4, -1.2, 0., 1.2, 2.4, 3.6, -2.1, -0.9, 0.3, 1.5, + 2.7, 3.9, -2.9, -1.7, -0.5, 0.7, 1.9, 3.1, -2.6, -1.4, -0.2, 1., 2.2, 3.4, -2.3, -1.1, 0.1, 1.3, 2.5, 3.7, -2., -0.8, 0.4, 1.6, + 2.8, 4., -2.8, -1.6, -0.4, 0.8, 2., 3.2, -2.5, -1.3, -0.1, 1.1, 2.3, 3.5, -2.2, -1., 0.2, 1.4, 2.6, 3.8, -1.9, -0.7, 0.5, 1.7, 2.9, 4.1}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oC, oH, oW}, {37.699997, 32.300041, 21.499989, 16.100004, 74.900024, 68.300003, 55.100006, 48.499969, 107.599983, 99.799988, + 84.200005, 76.400009, -221.5, -226.899994, -237.699997, -243.099991, -241.899994, -248.5, -261.700012, -268.299988, + -266.799988, -274.600006, -290.200012, -298.}, sd::DataType::FLOAT32); + + input.linspace(25,-0.5); + + sd::ops::conv2d op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_10) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, kH, kW, iC}, {-3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, 2.4, 2.7, 3., 3.3, + 3.6, 3.9, -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, 0.4, 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, + 3.1, 3.4, 3.7, 4., -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, -1., -0.7, -0.4, -0.1, 0.2, 0.5, 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, + 2.9, 3.2, 3.5, 3.8, 4.1}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oH, oW, oC}, {463.400055, 498.800018, 529.700012, 410.600006, 442.799988, 470.500031, 113.600006, 130.400009, 142.699982, + -63.999958, -19.600082, 20.300007, -85.600052, -45.999939, -10.899940, -144.100021, -124., -108.399994, -128.799988, -98.799973, -73.300011, + -150.400009, -125.200012, -104.500008, -133.300003, -120.399994, -112.000008, -170.199997, -154., -142.299988, -146.200012, -133.199997, -124.699997, + -88.000008, -80.800003, -78.099991, -170.200012, -173.199997, -180.699982, -223., -229.199997, -239.900009, -88., -90.400002, -97.300003, -323.200012, + -336.399994, -354.100037, -344.800018, -362.799988, -385.299957, -100.900002, -109.600006, -122.800003, -388.000031, -415.599976, -447.700012, -409.599976, + -442., -478.900024, -90.099991, -105.999992, -126.399994, 117.800003, 95.599991, 68.899994, 141.799988, 116.399994, 86.5, 171.200012, 159.200012, 142.699997}, sd::DataType::FLOAT32); + + input.linspace(25,-0.5); + + sd::ops::conv2d op; + auto results = op.evaluate({&input, 
&weights, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, sconv2d_1) { + float _expB[] = {10025.0f, 10350.0f, 10675.0f, 11000.0f, 11325.0f, 11650.0f, 13275.0f, 13600.0f, 13925.0f, 14250.0f, 14575.0f, 14900.0f, 16525.0f, 16850.0f, 17175.0f, 17500.0f, 17825.0f, 18150.0f, 19775.0f, 20100.0f, 20425.0f, 20750.0f, 21075.0f, 21400.0f, 23025.0f, 23350.0f, 23675.0f, 24000.0f, 24325.0f, 24650.0f, 26275.0f, 26600.0f, 26925.0f, 27250.0f, 27575.0f, 27900.0f, 38775.0f, 40350.0f, 41925.0f, 43500.0f, 45075.0f, 46650.0f, 54525.0f, 56100.0f, 57675.0f, 59250.0f, 60825.0f, 62400.0f, 70275.0f, 71850.0f, 73425.0f, 75000.0f, 76575.0f, 78150.0f, 86025.0f, 87600.0f, 89175.0f, 90750.0f, 92325.0f, 93900.0f, 101775.0f, 103350.0f, 104925.0f, 106500.0f, 108075.0f, 109650.0f, 117525.0f, 119100.0f, 120675.0f, 122250.0f, 123825.0f, 125400.0f, 67525.0f, 70350.0f, 73175.0f, 76000.0f, 78825.0f, 81650.0f, 95775.0f, 98600.0f, 101425.0f, 104250.0f, 107075.0f, 109900.0f, 124025.0f, 126850.0f, 129675.0f, 132500.0f, 135325.0f, 138150.0f, 152275.0f, 155100.0f, 157925.0f, 160750.0f, 163575.0f, 166400.0f, 180525.0f, 183350.0f, 186175.0f, 189000.0f, 191825.0f, 194650.0f, 208775.0f, 211600.0f, 214425.0f, 217250.0f, 220075.0f, 222900.0f, 119400.0f, 120350.0f, 121300.0f, 122250.0f, 123200.0f, 124150.0f, 128900.0f, 129850.0f, 130800.0f, 131750.0f, 132700.0f, 133650.0f, 138400.0f, 139350.0f, 140300.0f, 141250.0f, 142200.0f, 143150.0f, 147900.0f, 148850.0f, 149800.0f, 150750.0f, 151700.0f, 152650.0f, 157400.0f, 158350.0f, 159300.0f, 160250.0f, 161200.0f, 162150.0f, 166900.0f, 167850.0f, 168800.0f, 169750.0f, 170700.0f, 171650.0f, 273150.0f, 275350.0f, 277550.0f, 279750.0f, 281950.0f, 284150.0f, 295150.0f, 297350.0f, 
299550.0f, 301750.0f, 303950.0f, 306150.0f, 317150.0f, 319350.0f, 321550.0f, 323750.0f, 325950.0f, 328150.0f, 339150.0f, 341350.0f, 343550.0f, 345750.0f, 347950.0f, 350150.0f, 361150.0f, 363350.0f, 365550.0f, 367750.0f, 369950.0f, 372150.0f, 383150.0f, 385350.0f, 387550.0f, 389750.0f, 391950.0f, 394150.0f, 426900.0f, 430350.0f, 433800.0f, 437250.0f, 440700.0f, 444150.0f, 461400.0f, 464850.0f, 468300.0f, 471750.0f, 475200.0f, 478650.0f, 495900.0f, 499350.0f, 502800.0f, 506250.0f, 509700.0f, 513150.0f, 530400.0f, 533850.0f, 537300.0f, 540750.0f, 544200.0f, 547650.0f, 564900.0f, 568350.0f, 571800.0f, 575250.0f, 578700.0f, 582150.0f, 599400.0f, 602850.0f, 606300.0f, 609750.0f, 613200.0f, 616650.0f, 75025.0f, 75350.0f, 75675.0f, 76000.0f, 76325.0f, 76650.0f, 78275.0f, 78600.0f, 78925.0f, 79250.0f, 79575.0f, 79900.0f, 81525.0f, 81850.0f, 82175.0f, 82500.0f, 82825.0f, 83150.0f, 84775.0f, 85100.0f, 85425.0f, 85750.0f, 86075.0f, 86400.0f, 88025.0f, 88350.0f, 88675.0f, 89000.0f, 89325.0f, 89650.0f, 91275.0f, 91600.0f, 91925.0f, 92250.0f, 92575.0f, 92900.0f, 353775.0f, 355350.0f, 356925.0f, 358500.0f, 360075.0f, 361650.0f, 369525.0f, 371100.0f, 372675.0f, 374250.0f, 375825.0f, 377400.0f, 385275.0f, 386850.0f, 388425.0f, 390000.0f, 391575.0f, 393150.0f, 401025.0f, 402600.0f, 404175.0f, 405750.0f, 407325.0f, 408900.0f, 416775.0f, 418350.0f, 419925.0f, 421500.0f, 423075.0f, 424650.0f, 432525.0f, 434100.0f, 435675.0f, 437250.0f, 438825.0f, 440400.0f, 632525.0f, 635350.0f, 638175.0f, 641000.0f, 643825.0f, 646650.0f, 660775.0f, 663600.0f, 666425.0f, 669250.0f, 672075.0f, 674900.0f, 689025.0f, 691850.0f, 694675.0f, 697500.0f, 700325.0f, 703150.0f, 717275.0f, 720100.0f, 722925.0f, 725750.0f, 728575.0f, 731400.0f, 745525.0f, 748350.0f, 751175.0f, 754000.0f, 756825.0f, 759650.0f, 773775.0f, 776600.0f, 779425.0f, 782250.0f, 785075.0f, 787900.0f, 309400.0f, 310350.0f, 311300.0f, 312250.0f, 313200.0f, 314150.0f, 318900.0f, 319850.0f, 320800.0f, 321750.0f, 322700.0f, 323650.0f, 328400.0f, 
329350.0f, 330300.0f, 331250.0f, 332200.0f, 333150.0f, 337900.0f, 338850.0f, 339800.0f, 340750.0f, 341700.0f, 342650.0f, 347400.0f, 348350.0f, 349300.0f, 350250.0f, 351200.0f, 352150.0f, 356900.0f, 357850.0f, 358800.0f, 359750.0f, 360700.0f, 361650.0f, 713150.0f, 715350.0f, 717550.0f, 719750.0f, 721950.0f, 724150.0f, 735150.0f, 737350.0f, 739550.0f, 741750.0f, 743950.0f, 746150.0f, 757150.0f, 759350.0f, 761550.0f, 763750.0f, 765950.0f, 768150.0f, 779150.0f, 781350.0f, 783550.0f, 785750.0f, 787950.0f, 790150.0f, 801150.0f, 803350.0f, 805550.0f, 807750.0f, 809950.0f, 812150.0f, 823150.0f, 825350.0f, 827550.0f, 829750.0f, 831950.0f, 834150.0f, 1116900.0f, 1120350.0f, 1123800.0f, 1127250.0f, 1130700.0f, 1134150.0f, 1151400.0f, 1154850.0f, 1158300.0f, 1161750.0f, 1165200.0f, 1168650.0f, 1185900.0f, 1189350.0f, 1192800.0f, 1196250.0f, 1199700.0f, 1203150.0f, 1220400.0f, 1223850.0f, 1227300.0f, 1230750.0f, 1234200.0f, 1237650.0f, 1254900.0f, 1258350.0f, 1261800.0f, 1265250.0f, 1268700.0f, 1272150.0f, 1289400.0f, 1292850.0f, 1296300.0f, 1299750.0f, 1303200.0f, 1306650.0f,}; + Nd4jLong _expS[] = {4, 2, 6, 6, 6, 144, 36, 6, 1, 8192, 1, 99}; + NDArray exp(_expB, _expS); + + int sY = 1; + int sX = 1; + int pY = 0; + int pX = 0; + int iC = 2; + int oC = 3; + int kY = 5; + int kX = 5; + int iY = 10; + int iX = 10; + int B = 2; + + auto input = NDArrayFactory::create_('c', {B, iC, iY, iX}); + for (int e = 0; e < input->lengthOf(); e++) + input->p(e, e+1); + + auto weights = NDArrayFactory::create_('c', {oC, iC, kY, kX}); + for (int e = 0; e < weights->lengthOf(); e++) + weights->p(e, e+1); + weights->permutei({2,3,1,0}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, weights); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); + + block->getIArguments()->push_back(kY); + block->getIArguments()->push_back(kX); + + block->getIArguments()->push_back(sY); + 
block->getIArguments()->push_back(sX); + + block->getIArguments()->push_back(pY); + block->getIArguments()->push_back(pX); + + // dilation + block->getIArguments()->push_back(1); + block->getIArguments()->push_back(1); + + // NOT same mode + block->getIArguments()->push_back(0); + + sd::ops::sconv2d op; + + Nd4jStatus status = op.execute(block); + + ASSERT_EQ(ND4J_STATUS_OK, status); + auto output = variableSpace->getVariable(1)->getNDArray(); + + //exp.printShapeInfo("Expected shape"); + //output->printShapeInfo("Result shape"); + ASSERT_TRUE(exp.isSameShape(output)); + + //exp.printBuffer("Expctd buffer"); + //output->printBuffer("Result buffer"); + ASSERT_TRUE(exp.equalsTo(output)); + + delete block; + delete variableSpace; +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, sconv2d_2) { + TypeParam _expBFF[] = {108.9405008f, 109.5920008f, 110.2435008f, 110.8950008f, 111.5465008f, 112.1980008f, 115.4555008f, 116.1070008f, 116.7585008f, 117.410000f, 118.061500f, 118.7130009f, 121.9705009f, 122.6220009f, 123.2735009f, 123.9250009f, 124.5765009f, 125.2280009f, 128.4855009f, 129.1370009f, 129.7885009f, 130.4400009f, 131.09150f, 131.74300f, 135.0005010f, 135.6520010f, 136.3035010f, 136.9550010f, 137.6065010f, 138.2580010f, 141.5155010f, 142.1670010f, 142.8185010f, 143.4700010f, 144.1215010f, 144.7730010f, 248.9617514f, 250.670751f, 252.3797515f, 254.0887515f, 255.7977515f, 257.5067515f, 266.0517515f, 267.7607515f, 269.469751f, 271.1787516f, 272.8877516f, 274.5967516f, 283.1417516f, 284.8507516f, + 286.5597516f, 288.268751f, 289.9777517f, 291.6867517f, 300.2317517f, 301.9407517f, 303.6497517f, 305.3587517f, 307.067751f, 308.7767518f, 317.3217518f, 319.0307518f, 320.7397518f, 322.4487518f, 324.157751f, 325.866751f, 334.4117519f, 336.1207519f, 337.8297519f, 339.5387519f, 341.2477519f, 342.95675f, 388.9829964f, 391.7494964f, 394.5159964f, 397.2824964f, 400.048996f, 402.8154963f, 416.647996f, 419.4144962f, 
422.1809962f, 424.9474962f, 427.7139962f, 430.4804962f, 444.3129961f, 447.0794961f, 449.8459961f, 452.6124960f, 455.3789960f, 458.1454960f, 471.9779959f, 474.7444959f, 477.5109959f, 480.2774959f, 483.0439959f, 485.8104958f, 499.6429958f, 502.4094957f, 505.1759957f, 507.9424957f, + 510.7089957f, 513.4754957f, 527.3079956f, 530.0744956f, 532.8409956f, 535.607495f, 538.3739955f, 541.1404955f, 529.0042487f, 532.8282487f, 536.6522487f, 540.4762487f, 544.3002487f, 548.1242487f, 567.2442487f, 571.068248f, 574.892248f, 578.716248f, 582.540248f, 586.3642486f, 605.4842486f, 609.3082486f, 613.1322486f, 616.9562486f, 620.7802486f, 624.6042486f, 643.7242486f, 647.5482486f, 651.3722486f, 655.1962486f, 659.0202486f, 662.8442486f, 681.9642486f, 685.7882486f, 689.6122486f, 693.4362486f, 697.2602486f, 701.0842486f, 720.2042486f, 724.0282486f, 727.852248f, 731.676248f, 735.500248f, 739.324248f, 669.0255044f, 673.9070044f, 678.7885044f, 683.6700044f, 688.5515044f, 693.4330044f, + 717.8405044f, 722.7220044f, 727.6035044f, 732.4850044f, 737.3665044f, 742.2480044f, 766.6555043f, 771.5370043f, 776.4185043f, 781.3000043f, 786.1815043f, 791.0630043f, 815.4705043f, 820.3520043f, 825.2335043f, 830.1150043f, 834.9965043f, 839.8780043f, 864.2855042f, 869.1670042f, 874.0485042f, 878.9300042f, 883.8115042f, 888.6930042f, 913.1005042f, 917.9820042f, 922.8635042f, 927.7450042f, 932.6265042f, 937.5080042f, 809.0467424f, 814.9857424f, 820.9247424f, 826.8637423f, 832.8027423f, 838.7417423f, 868.4367421f, 874.3757421f, 880.3147420f, 886.2537420f, 892.1927420f, 898.13174f, 927.8267418f, 933.7657418f, 939.7047417f, 945.6437417f, 951.5827417f, 957.5217416f, 987.2167415f, 993.155741f, + 999.0947414f, 1005.0337414f, 1010.972741f, 1016.9117413f, 1046.6067412f, 1052.5457411f, 1058.4847411f, 1064.4237411f, 1070.3627410f, 1076.3017410f, 1105.996740f, 1111.9357408f, 1117.8747408f, 1123.8137408f, 1129.7527407f, 1135.6917407f, 949.0679815f, 956.0644814f, 963.060981f, 970.0574813f, 977.0539812f, 984.0504811f, 
1019.0329807f, 1026.0294807f, 1033.0259806f, 1040.0224805f, 1047.0189804f, 1054.0154804f, 1088.9979800f, 1095.9944799f, 1102.9909798f, 1109.987479f, 1116.9839797f, 1123.9804796f, 1158.9629792f, 1165.9594791f, 1172.9559791f, 1179.9524790f, 1186.9489789f, 1193.9454788f, 1228.9279785f, 1235.9244784f, 1242.9209783f, 1249.9174782f, 1256.913978f, 1263.9104781f, 1298.8929777f, 1305.8894776f, 1312.8859775f, 1319.8824775f, 1326.8789774f, 1333.8754773f, 1089.0892560f, 1097.1432561f, 1105.1972562f, 1113.251256f, 1121.3052563f, 1129.3592564f, 1169.6292568f, 1177.6832568f, 1185.7372569f, 1193.7912570f, 1201.845257f, 1209.8992571f, 1250.1692575f, 1258.2232576f, 1266.2772576f, 1274.3312577f, 1282.3852578f, 1290.4392579f, 1330.7092582f, 1338.7632583f, 1346.8172584f, 1354.8712584f, 1362.9252585f, 1370.9792586f, 1411.24925f, 1419.3032590f, 1427.3572591f, 1435.4112592f, 1443.465259f, 1451.5192593f, 1491.7892597f, 1499.8432598f, 1507.8972598f, 1515.9512599f, 1524.0052600f, 1532.059260f, 1229.1105073f, 1238.2220073f, 1247.3335073f, 1256.4450073f, 1265.5565073f, 1274.668007f, 1320.2255074f, 1329.3370074f, 1338.4485074f, 1347.5600075f, 1356.6715075f, 1365.7830075f, 1411.340507f, 1420.4520076f, 1429.5635076f, 1438.6750076f, 1447.7865076f, 1456.8980076f, 1502.4555077f, 1511.5670077f, 1520.6785077f, 1529.7900077f, 1538.9015077f, 1548.013007f, 1593.5705078f, 1602.6820078f, 1611.793507f, 1620.9050079f, 1630.0165079f, 1639.1280079f, 1684.6855080f, 1693.7970080f, 1702.9085080f, 1712.0200080f, 1721.1315080f, 1730.2430080f, 1369.1317613f, 1379.3007614f, 1389.4697614f, 1399.6387615f, 1409.8077615f, 1419.976761f, 1470.8217618f, 1480.9907618f, 1491.159761f, 1501.3287619f, 1511.4977619f, 1521.6667620f, 1572.5117622f, 1582.6807622f, 1592.8497623f, 1603.0187623f, 1613.1877624f, 1623.3567624f, 1674.2017626f, 1684.3707627f, 1694.5397627f, 1704.7087628f, 1714.8777628f, 1725.046762f, 1775.8917631f, 1786.0607631f, 1796.229763f, 1806.3987632f, 1816.5677632f, 1826.7367633f, 1877.5817635f, 1887.7507635f, 
1897.9197636f, 1908.0887636f, 1918.2577637f, 1928.4267637f, 304.3905022f, 305.0420022f, 305.6935022f, 306.3450022f, 306.9965022f, 307.6480022f, 310.9055022f, 311.5570022f, 312.208502f, 312.860002f, 313.5115023f, 314.1630023f, 317.4205023f, 318.0720023f, 318.7235023f, 319.3750023f, 320.0265023f, 320.6780023f, 323.9355023f, 324.5870023f, 325.2385023f, 325.8900023f, 326.541502f, 327.193002f, 330.4505024f, 331.1020024f, 331.7535024f, 332.4050024f, 333.0565024f, 333.7080024f, 336.9655024f, 337.6170024f, 338.2685024f, 338.9200024f, 339.5715024f, 340.223002f, 761.6617542f, 763.3707542f, 765.0797542f, 766.7887542f, 768.4977542f, 770.206754f, 778.7517543f, 780.4607543f, 782.1697543f, 783.8787543f, 785.5877543f, 787.2967543f, 795.8417544f, 797.5507544f, 799.2597544f, 800.9687544f, 802.6777544f, 804.3867544f, 812.9317545f, 814.6407545f, 816.3497545f, 818.0587545f, 819.7677545f, 821.4767545f, 830.0217546f, 831.7307546f, 833.4397546f, 835.1487546f, 836.8577546f, 838.5667546f, 847.1117547f, 848.8207547f, 850.5297547f, 852.2387547f, 853.9477547f, 855.6567547f, 1218.9329915f, 1221.6994915f, 1224.4659915f, 1227.232491f, 1229.9989914f, 1232.7654914f, 1246.5979913f, 1249.3644913f, 1252.1309913f, 1254.8974913f, 1257.6639913f, 1260.430491f, 1274.2629912f, 1277.029491f, 1279.7959911f, 1282.5624911f, 1285.3289911f, 1288.0954911f, 1301.9279910f, 1304.6944910f, 1307.4609910f, 1310.22749f, 1312.9939909f, 1315.7604909f, 1329.5929908f, 1332.3594908f, 1335.1259908f, 1337.8924908f, 1340.6589908f, 1343.4254908f, 1357.2579907f, + 1360.0244907f, 1362.7909906f, 1365.5574906f, 1368.3239906f, 1371.0904906f, 1676.2042479f, 1680.0282479f, 1683.8522479f, 1687.6762479f, 1691.5002479f, 1695.3242479f, 1714.4442479f, 1718.2682479f, 1722.0922479f, 1725.9162479f, 1729.7402479f, 1733.5642479f, 1752.6842479f, 1756.5082479f, 1760.3322479f, 1764.1562479f, 1767.9802479f, 1771.8042479f, 1790.9242479f, 1794.7482479f, 1798.5722479f, 1802.3962479f, 1806.2202479f, 1810.044247f, 1829.1642478f, 1832.9882478f, 
1836.8122478f, 1840.6362478f, 1844.4602478f, 1848.2842478f, 1867.4042478f, 1871.2282478f, 1875.0522478f, 1878.8762478f, 1882.7002478f, 1886.5242478f, 2133.4755029f, 2138.3570029f, 2143.2385029f, 2148.1200029f, 2153.0015029f, 2157.8830029f, 2182.2905028f, 2187.1720028f, 2192.0535028f, 2196.9350028f, 2201.8165028f, 2206.6980028f, 2231.1055028f, 2235.9870028f, 2240.8685028f, 2245.7500028f, 2250.6315028f, 2255.5130028f, 2279.9205027f, 2284.8020027f, 2289.6835027f, 2294.5650027f, 2299.4465027f, 2304.3280027f, 2328.7355027f, 2333.6170027f, 2338.4985027f, 2343.3800027f, 2348.2615027f, 2353.1430027f, 2377.5505026f, 2382.4320026f, 2387.3135026f, 2392.1950026f, 2397.0765026f, 2401.9580026f, 2590.7467330f, 2596.6857330f, 2602.6247329f, 2608.5637329f, 2614.5027329f, 2620.441732f, 2650.1367327f, 2656.0757327f, 2662.0147326f, 2667.9537326f, 2673.8927326f, 2679.8317325f, 2709.5267324f, 2715.465732f, 2721.4047323f, 2727.3437323f, 2733.282732f, 2739.2217322f, 2768.9167321f, 2774.8557320f, 2780.7947320f, 2786.7337320f, 2792.6727319f, 2798.6117319f, 2828.306731f, 2834.2457317f, 2840.1847317f, 2846.1237317f, 2852.0627316f, 2858.0017316f, 2887.6967314f, 2893.6357314f, 2899.5747314f, 2905.5137313f, 2911.4527313f, 2917.3917313f, 3048.0179587f, 3055.0144586f, 3062.0109585f, 3069.0074584f, 3076.0039584f, 3083.0004583f, 3117.9829579f, 3124.9794578f, 3131.9759578f, 3138.9724577f, 3145.9689576f, 3152.9654575f, 3187.947957f, 3194.9444571f, 3201.9409570f, 3208.9374569f, 3215.933956f, 3222.9304568f, 3257.9129564f, 3264.9094563f, 3271.9059562f, 3278.9024562f, 3285.8989561f, + 3292.8954560f, 3327.8779556f, 3334.874455f, 3341.8709555f, 3348.8674554f, 3355.8639553f, 3362.860455f, 3397.8429549f, 3404.8394548f, 3411.8359547f, 3418.8324546f, 3425.8289546f, 3432.8254545f, 3505.28927f, 3513.3432780f, 3521.3972781f, 3529.4512782f, 3537.5052782f, 3545.5592783f, 3585.8292787f, 3593.8832788f, 3601.9372788f, 3609.9912789f, 3618.0452790f, 3626.099279f, + 3666.3692794f, 3674.4232795f, 3682.4772796f, 
3690.5312796f, 3698.5852797f, 3706.6392798f, 3746.9092801f, 3754.9632802f, 3763.0172803f, 3771.0712804f, 3779.1252804f, 3787.1792805f, 3827.4492809f, 3835.50328f, 3843.5572810f, 3851.6112811f, 3859.6652812f, 3867.7192812f, 3907.9892816f, 3916.0432817f, 3924.097281f, + 3932.1512818f, 3940.2052819f, 3948.2592820f, 3962.5605113f, 3971.6720113f, 3980.783511f, 3989.8950114f, 3999.0065114f, 4008.1180114f, 4053.6755115f, 4062.7870115f, 4071.8985115f, 4081.0100115f, 4090.1215115f, 4099.2330115f, 4144.7905116f, 4153.9020116f, 4163.0135116f, 4172.1250116f, + 4181.236511f, 4190.3480117f, 4235.9055117f, 4245.0170117f, 4254.128511f, 4263.2400118f, 4272.3515118f, 4281.4630118f, 4327.0205119f, 4336.1320119f, 4345.2435119f, 4354.3550119f, 4363.4665119f, 4372.5780119f, 4418.1355120f, 4427.2470120f, 4436.3585120f, 4445.4700120f, 4454.581512f, 4463.6930121f, 4419.8317743f, 4430.0007744f, 4440.1697744f, 4450.338774f, 4460.5077745f, 4470.6767745f, 4521.521774f, 4531.6907748f, + 4541.8597748f, 4552.0287749f, 4562.1977749f, 4572.3667750f, 4623.2117752f, 4633.3807752f, 4643.5497753f, 4653.7187753f, 4663.8877754f, 4674.0567754f, 4724.9017756f, 4735.0707757f, 4745.2397757f, 4755.4087757f, 4765.5777758f, 4775.7467758f, 4826.591776f, 4836.7607761f, 4846.9297761f, 4857.0987762f, 4867.2677762f, 4877.4367763f, 4928.2817765f, 4938.4507765f, 4948.6197766f, 4958.7887766f, 4968.957776f, 4979.12677675f}; + Nd4jLong _expSFF[] = {4, 2, 10, 6, 6, 360, 36, 6, 1, typeid(TypeParam) == typeid(float) ? 
8192 : 16384, 1, 99,}; + NDArray expFF(_expBFF, _expSFF); + + auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto weightsD = NDArrayFactory::create('c', {5, 3, 5, 5}); + auto weightsP = NDArrayFactory::create('c', {10, 15, 1, 1}); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + weightsD.permutei({2,3,1,0}); + weightsP.permutei({2,3,1,0}); + + input.applyScalar(scalar::Divide, 100.0, input); + weightsD.applyScalar(scalar::Divide, 100.0, weightsD); + weightsP.applyScalar(scalar::Divide, 100.0, weightsP); + + sd::ops::sconv2d op; + + auto resultFF = op.evaluate({&input, &weightsD, &weightsP}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + + auto z = resultFF.at(0); + //z->printShapeInfo("FF shape"); + + + ASSERT_TRUE(z->isSameShape(&expFF)); + + //expFF.printBuffer("e"); + //z->printBuffer("z"); + ASSERT_TRUE(z->equalsTo(&expFF, 1e-3)); +} + +TYPED_TEST(TypedConvolutionTests1, sconv2d_3) { + auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); + auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); + auto bias = NDArrayFactory::create('c', {2}); + auto output = NDArrayFactory::create('c', {3, 2, 8, 8}); + output.assign(0.0); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + bias.linspace(1); + weightsD.permutei({2,3,1,0}); + weightsP.permutei({2,3,1,0}); + + auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); + + sd::ops::sconv2d op; + Nd4jStatus status = op.execute({&input, &weightsD, &weightsP, &bias}, {&output}, {1, 1, 1, 1, 0, 0, 1, 1, 0}); + auto result = op.evaluate({&input, &weightsD, &weightsP, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 0}); + + auto z = result.at(0); + + //printf("\n"); + //output.printBuffer("output"); + //z->printBuffer("z"); + + + //ASSERT_TRUE(expOutput.isSameShape(z)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, sconv2d_4) { + + int bS=1, iH=6,iW=6, iC=3,oC=2,mC=3, 
kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=6,oW=6; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, + 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, + 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, + 0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, + 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, + 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, 0.273894, 0.431796, 0.133231}); + NDArray weightsD('c', {kH, kW, iC, mC}, {0.5340641736984253, 0.8257383108139038, 0.3279532492160797, 0.27217748761177063, 0.05432872101664543, 0.31322699785232544, 0.6599581837654114, 0.35526034235954285, 0.5765137672424316}); + NDArray weightsP('c', {1, 1, iC*mC, oC}, {0.4442146420478821, 0.3362849950790405, 0.5215804576873779, 0.5305071473121643, 0.7323054075241089, 0.5168435573577881, 0.8601323962211609, 0.2587810158729553, 0.9473239779472351, 0.39540114998817444, 0.04835261031985283, 0.8724213242530823, 0.8607604503631592, 0.8382210731506348, 0.8573186993598938, 0.6496091485023499, 0.8864102959632874, 0.14267340302467346}); + NDArray biases('c', {1,oC}, {0.8807470202445984, 0.6262521147727966}); + + 
NDArray expOutput('c', {bS, oC, oH, oW}, {1.643804, 2.135067, 2.494167, 2.628944, 2.700440, 2.257452, 2.562539, 2.293667, 2.493985, 2.014933, 2.301736, 2.939066, 1.492952, + 2.026476, 1.771098, 2.013162, 1.315507, 1.289951, 2.831223, 2.196924, 2.028261, 2.024326, 2.983223, 1.809527, 1.434322, 2.513157, 1.826834, 1.608869, 1.297912, 1.212318, + 2.295934, 1.844615, 2.591148, 1.597267, 2.317755, 1.755642, 1.324064, 1.542060, 1.892052, 1.939339, 1.922781, 1.720199, 1.833396, 1.728024, 1.757968, 1.410675, 1.661960, + 2.096277, 1.178815, 1.637460, 1.254187, 1.491076, 0.968625, 0.986342, 2.116042, 1.536920, 1.504321, 1.490398, 2.136795, 1.351860, 1.148578, 1.817408, 1.327139, 1.288620, + 0.962232, 0.980667, 1.623775, 1.417320, 1.845710, 1.237095, 1.762792, 1.352515}); + + sd::ops::sconv2d op; + auto results = op.evaluate({&input, &weightsD, &weightsP, &biases}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +TYPED_TEST(TypedConvolutionTests1, conv2D_BP_Bias_1) { + TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, 15144.0, 20452.0, 15504.0}; + Nd4jLong _expWGradS[] = {4, 2, 1, 3, 3, 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; + NDArray expWGrad(_expWGradB, _expWGradS); + expWGrad.permutei({2,3,1,0}); + + TypeParam _expBGradB[] = {784.0, 1296.0}; + Nd4jLong _expBGradS[] = {2, 2, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 
8192 : 16384, 1, 99}; + + NDArray expBGrad(_expBGradB, _expBGradS); + + auto input = NDArrayFactory::create('c', {2, 1, 4, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); + + + TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; + NDArray expEps(_expEpsB, input.shapeInfo()); + + input.linspace(1); + weights.linspace(1); + epsilonNext.linspace(1); + weights.permutei({2,3,1,0}); + + sd::ops::conv2d_bp op; + + auto results = op.evaluate({&input, &weights, &bias, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); + + ASSERT_TRUE(results.size() == 3); + + auto epsilon = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expWGrad.isSameShape(gradW)); + + //expWGrad.printBuffer("Expctd buffer"); + // gradW->printBuffer("Result buffer"); + ASSERT_TRUE(expWGrad.equalsTo(gradW)); + + + ASSERT_TRUE(input.isSameShape(epsilon)); + + // expEps.printBuffer("Expctd buffer"); + //epsilon->printBuffer("Result buffer"); + ASSERT_TRUE(expEps.equalsTo(epsilon)); + + ASSERT_TRUE(expBGrad.isSameShape(gradB)); + + ASSERT_TRUE(expBGrad.equalsTo(gradB)); + + +} + + +TYPED_TEST(TypedConvolutionTests1, conv2D_BP_NoBias_1) { + TypeParam _expWGradB[] = {9312.0, 12580.0, 9528.0, 13168.0, 17712.0, 13360.0, 9960.0, 13348.0, 10032.0, 13344.0, 18148.0, 13848.0, 19312.0, 26160.0, 19888.0, 15144.0, 20452.0, 15504.0}; + Nd4jLong _expWGradS[] = {4, 2, 1, 3, 3, 9, 9, 3, 1, typeid(TypeParam) == typeid(float) ? 
8192 : 16384, 1, 99}; + NDArray expWGrad(_expWGradB, _expWGradS); + expWGrad.permutei({2,3,1,0}); + + auto input = NDArrayFactory::create('c', {2, 1, 4, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); + auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); + + + TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; + NDArray expEps(_expEpsB, input.shapeInfo()); + + input.linspace(1); + weights.linspace(1); + epsilonNext.linspace(1); + weights.permutei({2,3,1,0}); + + sd::ops::conv2d_bp op; + + auto results = op.evaluate({&input, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}, {}); + + ASSERT_TRUE(results.size() == 2); + + auto epsilon = results.at(0); + auto gradW = results.at(1); + + ASSERT_TRUE(expWGrad.isSameShape(gradW)); + + //expWGrad.printBuffer("Expctd buffer"); + // gradW->printBuffer("Result buffer"); + ASSERT_TRUE(expWGrad.equalsTo(gradW)); + + + ASSERT_TRUE(input.isSameShape(epsilon)); + + // expEps.printBuffer("Expctd buffer"); + //epsilon->printBuffer("Result buffer"); + ASSERT_TRUE(expEps.equalsTo(epsilon)); + + +} + +TYPED_TEST(TypedConvolutionTests1, sconv2d_conv2d_1) { + + auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto weightsD = NDArrayFactory::create('c', {5, 5, 3, 2}, {1.f, 76.f, 26.f, 101.f, 51.f, 126.f, 2.f, 77.f, 27.f, 102.f, 52.f, 127.f, 3.f, 78.f, 28.f, 103.f, 53.f, 128.f, 4.f, 79.f, 29.f, 104.f, 54.f, 129.f, 5.f, 80.f, 30.f, 105.f, 55.f, 130.f, + 6.f, 81.f, 31.f, 106.f, 56.f, 131.f, 7.f, 82.f, 32.f, 107.f, 57.f, 132.f, 8.f, 83.f, 33.f, 108.f, 58.f, 133.f, 9.f, 84.f, 34.f, 109.f, 59.f, 134.f, 10.f, 85.f, 35.f, 110.f, 60.f, 135.f, + 11.f, 86.f, 36.f, 111.f, 61.f, 136.f, 12.f, 87.f, 37.f, 112.f, 62.f, 137.f, 13.f, 88.f, 38.f, 113.f, 63.f, 138.f, 14.f, 89.f, 39.f, 
114.f, 64.f, 139.f, 15.f, 90.f, 40.f, 115.f, 65.f, 140.f, + 16.f, 91.f, 41.f, 116.f, 66.f, 141.f, 17.f, 92.f, 42.f, 117.f, 67.f, 142.f, 18.f, 93.f, 43.f, 118.f, 68.f, 143.f, 19.f, 94.f, 44.f, 119.f, 69.f, 144.f, 20.f, 95.f, 45.f, 120.f, 70.f, 145.f, + 21.f, 96.f, 46.f, 121.f, 71.f, 146.f, 22.f, 97.f, 47.f, 122.f, 72.f, 147.f, 23.f, 98.f, 48.f, 123.f, 73.f, 148.f, 24.f, 99.f, 49.f, 124.f, 74.f, 149.f, 25.f, 100.f, 50.f, 125.f, 75.f, 150.f}); + auto weightsP = NDArrayFactory::create('c', {1, 1, 6, 10}, {0.0001f, 0.0007f, 0.0013f, 0.0019f, 0.0025f, 0.0031f, 0.0037f, 0.0043f, 0.0049f, 0.0055f,0.0002f, 0.0008f, 0.0014f, 0.0020f, 0.0026f, 0.0032f, 0.0038f, 0.0044f, 0.0050f, 0.0056f, + 0.0003f, 0.0009f, 0.0015f, 0.0021f, 0.0027f, 0.0033f, 0.0039f, 0.0045f, 0.0051f, 0.0057f,0.0004f, 0.0010f, 0.0016f, 0.0022f, 0.0028f, 0.0034f, 0.0040f, 0.0046f, 0.0052f, 0.0058f, + 0.0005f, 0.0011f, 0.0017f, 0.0023f, 0.0029f, 0.0035f, 0.0041f, 0.0047f, 0.0053f, 0.0059f,0.0006f, 0.0012f, 0.0018f, 0.0024f, 0.0030f, 0.0036f, 0.0042f, 0.0048f, 0.0054f, 0.0060f}); + + auto expFF = NDArrayFactory::create('c', {2, 6, 6, 6}, {10025.0f,10350.0f,10675.0f,11000.0f,11325.0f,11650.0f,13275.0f,13600.0f,13925.0f,14250.0f,14575.0f,14900.0f,16525.0f,16850.0f, + 17175.0f,17500.0f,17825.0f,18150.0f,19775.0f,20100.0f,20425.0f,20750.0f,21075.0f,21400.0f,23025.0f,23350.0f,23675.0f,24000.0f, + 24325.0f,24650.0f,26275.0f,26600.0f,26925.0f,27250.0f,27575.0f,27900.0f,53150.0f,55350.0f,57550.0f,59750.0f,61950.0f,64150.0f, + 75150.0f,77350.0f,79550.0f,81750.0f,83950.0f,86150.0f,97150.0f,99350.0f,101550.0f,103750.0f,105950.0f,108150.0f,119150.0f, + 121350.0f,123550.0f,125750.0f,127950.0f,130150.0f,141150.0f,143350.0f,145550.0f,147750.0f,149950.0f,152150.0f,163150.0f, + 165350.0f,167550.0f,169750.0f,171950.0f,174150.0f,119400.0f,120350.0f,121300.0f,122250.0f,123200.0f,124150.0f,128900.0f, + 129850.0f,130800.0f,131750.0f,132700.0f,133650.0f,138400.0f,139350.0f,140300.0f,141250.0f,142200.0f,143150.0f,147900.0f, + 
148850.0f,149800.0f,150750.0f,151700.0f,152650.0f,157400.0f,158350.0f,159300.0f,160250.0f,161200.0f,162150.0f,166900.0f, + 167850.0f,168800.0f,169750.0f,170700.0f,171650.0f,350025.0f,352850.0f,355675.0f,358500.0f,361325.0f,364150.0f,378275.0f, + 381100.0f,383925.0f,386750.0f,389575.0f,392400.0f,406525.0f,409350.0f,412175.0f,415000.0f,417825.0f,420650.0f,434775.0f, + 437600.0f,440425.0f,443250.0f,446075.0f,448900.0f,463025.0f,465850.0f,468675.0f,471500.0f,474325.0f,477150.0f,491275.0f, + 494100.0f,496925.0f,499750.0f,502575.0f,505400.0f,353775.0f,355350.0f,356925.0f,358500.0f,360075.0f,361650.0f,369525.0f, + 371100.0f,372675.0f,374250.0f,375825.0f,377400.0f,385275.0f,386850.0f,388425.0f,390000.0f,391575.0f,393150.0f,401025.0f, + 402600.0f,404175.0f,405750.0f,407325.0f,408900.0f,416775.0f,418350.0f,419925.0f,421500.0f,423075.0f,424650.0f,432525.0f, + 434100.0f,435675.0f,437250.0f,438825.0f,440400.0f,771900.0f,775350.0f,778800.0f,782250.0f,785700.0f,789150.0f,806400.0f, + 809850.0f,813300.0f,816750.0f,820200.0f,823650.0f,840900.0f,844350.0f,847800.0f,851250.0f,854700.0f,858150.0f,875400.0f, + 878850.0f,882300.0f,885750.0f,889200.0f,892650.0f,909900.0f,913350.0f,916800.0f,920250.0f,923700.0f,927150.0f,944400.0f, + 947850.0f,951300.0f,954750.0f,958200.0f,961650.0f,107525.0f,107850.0f,108175.0f,108500.0f,108825.0f,109150.0f,110775.0f, + 111100.0f,111425.0f,111750.0f,112075.0f,112400.0f,114025.0f,114350.0f,114675.0f,115000.0f,115325.0f,115650.0f,117275.0f, + 117600.0f,117925.0f,118250.0f,118575.0f,118900.0f,120525.0f,120850.0f,121175.0f,121500.0f,121825.0f,122150.0f,123775.0f, + 124100.0f,124425.0f,124750.0f,125075.0f,125400.0f,713150.0f,715350.0f,717550.0f,719750.0f,721950.0f,724150.0f,735150.0f, + 737350.0f,739550.0f,741750.0f,743950.0f,746150.0f,757150.0f,759350.0f,761550.0f,763750.0f,765950.0f,768150.0f,779150.0f, + 781350.0f,783550.0f,785750.0f,787950.0f,790150.0f,801150.0f,803350.0f,805550.0f,807750.0f,809950.0f,812150.0f,823150.0f, + 
825350.0f,827550.0f,829750.0f,831950.0f,834150.0f,404400.0f,405350.0f,406300.0f,407250.0f,408200.0f,409150.0f,413900.0f, + 414850.0f,415800.0f,416750.0f,417700.0f,418650.0f,423400.0f,424350.0f,425300.0f,426250.0f,427200.0f,428150.0f,432900.0f,433850.0f,434800.0f,435750.0f,436700.0f,437650.0f,442400.0f,443350.0f,444300.0f,445250.0f,446200.0f,447150.0f,451900.0f,452850.0f,453800.0f,454750.0f,455700.0f,456650.0f,1197525.0f,1200350.0f,1203175.0f,1206000.0f,1208825.0f,1211650.0f,1225775.0f,1228600.0f,1231425.0f,1234250.0f,1237075.0f,1239900.0f,1254025.0f,1256850.0f,1259675.0f,1262500.0f,1265325.0f,1268150.0f,1282275.0f,1285100.0f,1287925.0f,1290750.0f,1293575.0f,1296400.0f,1310525.0f,1313350.0f,1316175.0f,1319000.0f,1321825.0f,1324650.0f,1338775.0f,1341600.0f,1344425.0f,1347250.0f,1350075.0f,1352900.0f,826275.0f,827850.0f,829425.0f,831000.0f,832575.0f,834150.0f,842025.0f,843600.0f,845175.0f,846750.0f,848325.0f,849900.0f,857775.0f,859350.0f,860925.0f,862500.0f,864075.0f,865650.0f,873525.0f,875100.0f,876675.0f,878250.0f,879825.0f,881400.0f,889275.0f,890850.0f,892425.0f,894000.0f,895575.0f,897150.0f,905025.0f,906600.0f,908175.0f,909750.0f,911325.0f,912900.0f,1806900.0f,1810350.0f,1813800.0f,1817250.0f,1820700.0f,1824150.0f,1841400.0f,1844850.0f,1848300.0f,1851750.0f,1855200.0f,1858650.0f,1875900.0f,1879350.0f,1882800.0f,1886250.0f,1889700.0f,1893150.0f,1910400.0f,1913850.0f,1917300.0f,1920750.0f,1924200.0f,1927650.0f,1944900.0f,1948350.0f,1951800.0f,1955250.0f,1958700.0f,1962150.0f,1979400.0f,1982850.0f,1986300.0f,1989750.0f,1993200.0f,1996650.f}); + auto exp2FF = NDArrayFactory::create('c', {2, 10, 6, 6}, {827.4900282f,832.2350283f,836.9800284f,841.725028f,846.4700287f,851.2150288f,874.9400293f,879.6850294f,884.4300295f,889.1750296f,893.9200297f,898.665029f, + 922.3900304f,927.1350305f,931.8800306f,936.6250307f,941.3700308f,946.1150309f,969.8400315f,974.5850316f,979.3300317f,984.0750318f,988.8200319f,993.5650320f, + 
1017.2900326f,1022.0350327f,1026.7800328f,1031.5250329f,1036.2700330f,1041.0150331f,1064.7400337f,1069.4850338f,1074.2300339f,1078.9750340f,1083.7200341f, + 1088.4650342f,1822.4550553f,1833.995055f,1845.5350558f,1857.075056f,1868.6150563f,1880.1550566f,1937.8550578f,1949.3950581f,1960.9350583f,1972.4750586f, + 1984.015058f,1995.5550591f,2053.2550604f,2064.7950606f,2076.3350609f,2087.8750611f,2099.4150614f,2110.955061f,2168.6550629f,2180.1950632f,2191.7350634f, + 2203.2750637f,2214.8150639f,2226.3550642f,2284.0550655f,2295.5950657f,2307.1350660f,2318.6750662f,2330.2150665f,2341.7550667f,2399.4550680f,2410.9950683f, + 2422.5350685f,2434.0750688f,2445.6150690f,2457.1550693f,2817.419968f,2835.7549686f,2854.0899683f,2872.4249680f,2890.7599677f,2909.0949674f,3000.7699660f, + 3019.104965f,3037.4399655f,3055.7749652f,3074.1099649f,3092.4449646f,3184.1199632f,3202.4549629f,3220.789962f,3239.1249624f,3257.4599621f,3275.7949618f, + 3367.4699604f,3385.8049601f,3404.1399598f,3422.474959f,3440.8099593f,3459.1449590f,3550.8199576f,3569.1549573f,3587.4899570f,3605.8249567f,3624.1599565f, + 3642.4949562f,3734.1699548f,3752.5049545f,3770.8399542f,3789.1749539f,3807.5099536f,3825.8449534f,3812.385098f,3837.5150988f,3862.6450994f,3887.7751000f, + 3912.9051006f,3938.0351012f,4063.6851041f,4088.8151047f,4113.9451053f,4139.0751059f,4164.2051065f,4189.3351071f,4314.9851100f,4340.1151106f,4365.2451112f, + 4390.3751118f,4415.5051124f,4440.6351130f,4566.2851159f,4591.4151165f,4616.5451171f,4641.6751177f,4666.805118f,4691.9351188f,4817.5851218f,4842.7151224f, + 4867.8451230f,4892.975123f,4918.1051241f,4943.2351247f,5068.8851277f,5094.0151283f,5119.1451288f,5144.2751294f,5169.4051300f,5194.5351306f,4807.3499803f, + 4839.2749801f,4871.1999799f,4903.1249797f,4935.0499795f,4966.9749793f,5126.5999784f,5158.5249782f,5190.4499780f,5222.3749778f,5254.2999777f,5286.2249775f, + 
5445.8499765f,5477.774976f,5509.6999762f,5541.6249760f,5573.5499758f,5605.4749756f,5765.0999747f,5797.0249745f,5828.9499743f,5860.8749741f,5892.7999739f, + 5924.724973f,6084.3499728f,6116.2749726f,6148.1999724f,6180.1249723f,6212.0499721f,6243.9749719f,6403.59997f,6435.5249708f,6467.4499706f,6499.3749704f, + 6531.2999702f,6563.2249700f,5802.3150007f,5841.0350006f,5879.7550005f,5918.4750004f,5957.195000f,5995.9150003f,6189.5149999f,6228.2349998f,6266.9549997f, + 6305.6749996f,6344.3949995f,6383.114999f,6576.7149990f,6615.4349990f,6654.1549989f,6692.8749988f,6731.5949987f,6770.3149986f,6963.9149982f,7002.6349981f, + 7041.3549981f,7080.0749980f,7118.7949979f,7157.5149978f,7351.1149974f,7389.8349973f,7428.5549972f,7467.2749972f,7505.9949971f,7544.7149970f,7738.3149966f,7777.0349965f,7815.7549964f,7854.4749963f,7893.1949963f,7931.9149962f,6797.2799488f,6842.794948f,6888.3099489f,6933.8249490f,6979.3399491f,7024.8549492f,7252.4299497f,7297.9449498f,7343.4599499f,7388.9749500f,7434.489950f,7480.0049501f,7707.5799506f,7753.0949507f,7798.6099508f,7844.1249509f,7889.6399510f,7935.1549511f,8162.7299515f,8208.2449516f,8253.7599517f,8299.2749518f,8344.7899519f,8390.3049520f,8617.8799525f,8663.394952f,8708.9099526f,8754.4249527f,8799.9399528f,8845.4549529f,9073.0299534f,9118.5449535f,9164.0599536f,9209.5749537f,9255.089953f,9300.604953f,7792.2451647f,7844.5551655f,7896.8651663f,7949.1751671f,8001.4851679f,8053.7951686f,8315.3451725f,8367.6551733f,8419.9651741f,8472.2751749f,8524.585175f,8576.8951764f,8838.4451803f,8890.7551811f,8943.0651819f,8995.3751827f,9047.6851834f,9099.9951842f,9361.5451881f,9413.8551889f,9466.1651897f,9518.475190f,9570.7851912f,9623.0951920f,9884.6451959f,9936.9551967f,9989.2651975f,10041.5751982f,10093.8851990f,10146.1951998f,10407.7452037f,10460.0552045f,10512.3652053f,10564.6752060f,10616.9852068f,10669.2952076f,8787.210074f,8846.3150748f,8905.4200750f,8964.5250752f,9023.6300755f,9082.7350757f,9378.2600768f,9437.3650770f,9496.4700773f,9555.5750775f,9614.
6800777f,9673.7850779f,9969.3100791f,10028.4150793f,10087.5200795f,10146.625079f,10205.7300800f,10264.8350802f,10560.3600813f,10619.465081f,10678.5700818f,10737.6750820f,10796.7800822f,10855.8850825f,11151.4100836f,11210.5150838f,11269.6200840f,11328.7250843f,11387.8300845f,11446.9350847f,11742.4600858f,11801.5650861f,11860.6700863f,11919.7750865f,11978.880086f,12037.9850870f,9782.1750935f,9848.0750935f,9913.9750934f,9979.8750934f,10045.7750934f,10111.6750933f,10441.1750931f,10507.0750931f,10572.9750931f,10638.8750930f,10704.7750930f,10770.6750930f,11100.1750928f,11166.0750927f,11231.9750927f,11297.8750927f,11363.7750926f,11429.6750926f,11759.1750924f,11825.0750924f,11890.9750923f,11956.8750923f,12022.7750923f,12088.6750922f,12418.175092f,12484.0750920f,12549.9750920f,12615.8750919f,12681.7750919f,12747.6750919f,13077.1750917f,13143.0750916f,13208.9750916f,13274.8750916f,13340.7750915f,13406.6750915f,2250.990060f,2255.7350610f,2260.4800611f,2265.2250612f,2269.9700613f,2274.7150614f,2298.4400619f,2303.185062f,2307.9300622f,2312.6750623f,2317.4200624f,2322.1650625f,2345.8900630f,2350.6350631f,2355.380063f,2360.1250634f,2364.8700635f,2369.6150636f,2393.3400641f,2398.0850642f,2402.8300643f,2407.5750644f,2412.320064f,2417.0650647f,2440.7900652f,2445.5350653f,2450.2800654f,2455.0250655f,2459.7700656f,2464.515065f,2488.2400663f,2492.9850664f,2497.7300665f,2502.4750666f,2507.2200667f,2511.9650668f,5284.4551315f,5295.9951318f,5307.535132f,5319.0751323f,5330.6151326f,5342.1551328f,5399.8551341f,5411.3951343f,5422.9351346f,5434.475134f,5446.0151351f,5457.5551354f,5515.2551366f,5526.7951369f,5538.3351371f,5549.8751374f,5561.4151376f,5572.9551379f,5630.6551392f,5642.1951394f,5653.7351397f,5665.2751399f,5676.8151402f,5688.3551404f,5746.0551417f,5757.5951420f,5769.1351422f,5780.6751425f,5792.2151427f,5803.7551430f,5861.455144f,5872.9951445f,5884.5351448f,5896.0751450f,5907.6151453f,5919.1551455f,8317.919884f,8336.2548841f,8354.5898838f,8372.9248835f,8391.2598832f,8409.59488f,8501.
2698815f,8519.6048813f,8537.9398810f,8556.2748807f,8574.6098804f,8592.9448801f,8684.6198787f,8702.9548784f,8721.2898782f,8739.6248779f,8757.9598776f,8776.2948773f,8867.9698759f,8886.3048756f,8904.6398753f,8922.9748751f,8941.3098748f,8959.6448745f,9051.3198731f,9069.6548728f,9087.9898725f,9106.3248722f,9124.6598720f,9142.9948717f,9234.6698703f,9253.0048700f,9271.3398697f,9289.6748694f,9308.0098691f,9326.3448689f,11351.3852747f,11376.5152753f,11401.6452759f,11426.7752765f,11451.9052771f,11477.0352777f,11602.6852806f,11627.8152812f,11652.9452818f,11678.0752824f,11703.2052830f,11728.335283f,11853.9852865f,11879.1152871f,11904.2452877f,11929.3752883f,11954.505288f,11979.6352894f,12105.2852924f,12130.4152930f,12155.545293f,12180.6752941f,12205.8052947f,12230.9352953f,12356.5852983f,12381.715298f,12406.8452994f,12431.9753000f,12457.1053006f,12482.2353012f,12607.8853041f,12633.0153047f,12658.1453053f,12683.2753059f,12708.4053065f,12733.5353071f,14384.8499244f,14416.7749242f,14448.6999240f,14480.6249238f,14512.549923f,14544.4749235f,14704.0999225f,14736.024922f,14767.9499222f,14799.8749220f,14831.7999218f,14863.7249216f,15023.3499207f,15055.2749205f,15087.1999203f,15119.1249201f,15151.0499199f,15182.9749197f,15342.5999188f,15374.5249186f,15406.4499184f,15438.374918f,15470.2999181f,15502.2249179f,15661.84991f,15693.7749168f,15725.6999166f,15757.6249164f,15789.5499162f,15821.4749160f,15981.0999151f,16013.0249149f,16044.9499147f,16076.8749145f,16108.7999143f,16140.7249142f,17418.314976f,17457.0349761f,17495.7549760f,17534.4749759f,17573.1949758f,17611.9149757f,17805.5149753f,17844.234975f,17882.9549752f,17921.6749751f,17960.3949750f,17999.1149749f,18192.7149745f,18231.4349744f,18270.154974f,18308.8749743f,18347.5949742f,18386.3149741f,18579.9149737f,18618.6349736f,18657.3549735f,18696.074973f,18734.7949734f,18773.5149733f,18967.1149729f,19005.8349728f,19044.5549727f,19083.2749726f,19121.994972f,19160.7149725f,19354.3149721f,19393.0349720f,19431.7549719f,19470.4749718f,19509.194
9717f,19547.914971f,20451.7799765f,20497.2949766f,20542.8099767f,20588.3249768f,20633.8399769f,20679.3549770f,20906.929977f,20952.4449775f,20997.9599776f,21043.4749777f,21088.9899778f,21134.5049779f,21362.0799784f,21407.5949785f,21453.1099786f,21498.624978f,21544.139978f,21589.6549788f,21817.2299793f,21862.7449794f,21908.2599795f,21953.7749796f,21999.2899797f,22044.8049798f,22272.3799802f,22317.8949803f,22363.4099804f,22408.9249805f,22454.4399806f,22499.9549807f,22727.529981f,22773.044981f,22818.5599813f,22864.0749814f,22909.5899815f,22955.1049816f,23485.2453985f,23537.555399f,23589.8654000f,23642.1754008f,23694.4854016f,23746.7954024f,24008.3454063f,24060.655407f,24112.9654078f,24165.2754086f,24217.5854094f,24269.8954102f,24531.4454141f,24583.7554148f,24636.0654156f,24688.3754164f,24740.6854172f,24792.99541f,25054.545421f,25106.8554226f,25159.1654234f,25211.4754242f,25263.7854250f,25316.0954257f,25577.6454296f,25629.9554304f,25682.2654312f,25734.5754320f,25786.8854328f,25839.1954335f,26100.7454374f,26153.0554382f,26205.3654390f,26257.6754398f,26309.985440f,26362.2954413f,26518.7101423f,26577.8151425f,26636.920142f,26696.0251430f,26755.1301432f,26814.2351434f,27109.7601446f,27168.8651448f,27227.9701450f,27287.0751452f,27346.1801455f,27405.2851457f,27700.8101468f,27759.9151470f,27819.0201473f,27878.1251475f,27937.2301477f,27996.33514f,28291.8601491f,28350.9651493f,28410.0701495f,28469.175149f,28528.2801500f,28587.3851502f,28882.9101513f,28942.0151516f,29001.1201518f,29060.2251520f,29119.3301522f,29178.4351525f,29473.9601536f,29533.0651538f,29592.1701540f,29651.2751543f,29710.3801545f,29769.4851547f,29552.1750826f,29618.0750825f,29683.9750825f,29749.8750825f,29815.7750824f,29881.6750824f,30211.1750822f,30277.0750822f,30342.9750821f,30408.8750821f,30474.7750821f,30540.6750820f,30870.175081f,30936.0750818f,31001.9750818f,31067.8750817f,31133.7750817f,31199.6750817f,31529.1750815f,31595.075081f,31660.9750814f,31726.8750814f,31792.7750813f,31858.6750813f,32188.1750811f,32
254.0750811f,32319.975081f,32385.8750810f,32451.7750810f,32517.6750809f,32847.1750808f,32913.0750807f,32978.9750807f,33044.875080f,33110.7750806f,33176.67508062f}); + + input.linspace(1); + + sd::ops::sconv2d op; + auto resultFF = op.evaluate({&input, &weightsD}, {}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); + + auto z = resultFF.at(0); + + ASSERT_TRUE(z->isSameShape(&expFF)); + ASSERT_TRUE(z->equalsTo(&expFF, 1)); + + + sd::ops::conv2d op2d; + // weightsP.printShapeInfo(); + auto result2D = op2d.evaluate({z, &weightsP}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + + auto z2d = result2D.at(0); + // z2d->printBuffer(); + + ASSERT_TRUE(z2d->isSameShape(&exp2FF)); + ASSERT_TRUE(z2d->equalsTo(&exp2FF)); +} + +TEST_F(ConvolutionTests1, deconv2d_bp_1) { + + int bS=3, iH=4,iW=4, iC=3,oC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + NDArray weights('c',{kH,kW,oC,iC}, {1,3,5,2,4,6}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW},sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iH, iW}, {35.f, 38.f, 41.f, 44.f, 47.f, 50.f, 53.f, 56.f, 59.f, 62.f, 65.f, 68.f, 71.f, 74.f, + 77.f, 80.f, 71.f, 78.f, 85.f, 92.f, 99.f, 106.f, 113.f, 120.f, 127.f, 134.f, 141.f, 148.f, 155.f, 162.f, 169.f, + 176.f, 107.f, 118.f, 129.f, 140.f, 151.f, 162.f, 173.f, 184.f, 195.f, 206.f, 217.f, 228.f, 239.f, 250.f, 261.f, 272.f, + 131.f, 134.f, 137.f, 140.f, 143.f, 146.f, 149.f, 152.f, 155.f, 158.f, 161.f, 164.f, 167.f, 170.f, 173.f, 176.f, 295.f, + 302.f, 309.f, 316.f, 323.f, 330.f, 337.f, 344.f, 351.f, 358.f, 365.f, 372.f, 379.f, 386.f, 393.f, 400.f, 459.f, 470.f, + 481.f, 492.f, 503.f, 514.f, 525.f, 536.f, 547.f, 558.f, 569.f, 580.f, 591.f, 602.f, 613.f, 624.f, 227.f, 230.f, 233.f, + 236.f, 239.f, 242.f, 245.f, 248.f, 251.f, 254.f, 257.f, 260.f, 263.f, 266.f, 269.f, 
272.f, 519.f, 526.f, 533.f, 540.f, + 547.f, 554.f, 561.f, 568.f, 575.f, 582.f, 589.f, 596.f, 603.f, 610.f, 617.f, 624.f, 811.f, 822.f, 833.f, 844.f, 855.f, + 866.f, 877.f, 888.f, 899.f, 910.f, 921.f, 932.f, 943.f, 954.f, 965.f, 976.f}, sd::DataType::FLOAT32); + NDArray expGradW('c', {kH, kW, oC, iC}, {160008., 191112., 222216., 203400., 246792., 290184.f}, sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {1944.f, 2712.f}, sd::DataType::FLOAT32); + + input.linspace(1); + bias.linspace(1); + gradO.linspace(1); + + + sd::ops::deconv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_bp_2) { + + int bS=3, iH=4,iW=4, iC=3,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=4; // 5,4 + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c',{iC, oC, kH, kW}, {1., 7., 2., 10., 3., 8., 4., 11., 5., 9., 6., 12.}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW},sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iH, iW}, {-77.400002, -77.199997, -77., -76.800003, -76.599998, -76.400002, -76.200005, -76., -75.800003, -75.599998, -75.399994, + -75.199997, -11.32, -11.29, -11.26, -11.23, -100.839996, 
-100.580002, -100.32, -100.059998, -99.800003, -99.540001, -99.279999, -99.019997, -98.760002, -98.50, + -98.240005, -97.979996, -26.52, -26.450001, -26.380001, -26.309999, -124.279999, -123.959991, -123.639999, -123.32, -123., -122.68, -122.360001, -122.040001, + -121.720001, -121.400009, -121.080002, -120.759995, -41.720001, -41.610001, -41.50, -41.389999, -71., -70.800003, -70.599998, -70.399994, -70.199997, -70., -69.800003, -69.600006, -69.400002, -69.199997, -69., -68.799995, -10.360001, -10.33, -10.30, -10.27, -92.519997, -92.260002, -92., -91.740005, -91.479996, -91.220001, -90.960007, -90.700005, -90.440002, -90.18, -89.919998, -89.660004, -24.280001, -24.209999, -24.139999, -24.07, -114.040001, -113.720001, -113.400009, -113.080002, -112.759995, -112.440002, -112.120003, -111.800003, -111.480003, -111.159996, -110.839996, -110.520004, -38.200001, -38.09, -37.980003, -37.869999, -64.599998, -64.400002, -64.199997, -64., -63.799995, -63.599998, -63.400002, -63.199997, -63., -62.799995, -62.599998, -62.400002, -9.40, -9.37, -9.34, -9.309999, -84.200005, -83.940002, -83.68, -83.419998, -83.160004, -82.900002, -82.639999, -82.379997, -82.119995, -81.860001, -81.600006, -81.339996, -22.040001, -21.970001, -21.90, -21.83, -103.800003, -103.480003, -103.159996, -102.839996, -102.520004, -102.200005, -101.879997, -101.559998, -101.239998, -100.919998, -100.599998, -100.279999, -34.68, -34.57, -34.459999, -34.349998}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, oC, kH, kW}, {-3010.799805, -2502.420410, -2899.439209, -2407.380615, -242.159332, -437.460510, -253.680466, -434.580048, 2526.479980, 1627.500000, 2392.079834, 1538.220093}, sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-173.040009, -165.360016}, sd::DataType::FLOAT32); + + input.linspace(70., -1); + gradO.linspace(-4, 0.01); + + sd::ops::deconv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + + 
ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_bp_3) { + + int bS=3, iH=4,iW=4, iC=3,oC=2, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=5,oW=4; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0 - [kH, kW, oC, iC], 1 - [iC, oC, kH, kW], 2 - [iC, kH, kW, oC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c',{iC, kH, kW, oC}, {1., 4., 7., 10., 2., 5., 8., 11., 3., 6., 9., 12.}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iH, iW, iC}, {-86.5, -102.320007, -118.139999, -86.060005, -101.800003, -117.540001, -85.619995, -101.279999, -116.940002, -85.18, + -100.759995, -116.339996, -84.740005, -100.239998, -115.739998, -84.300003, -99.720001, -115.139999, -83.860001, -99.199997, -114.539993, -83.419998, -98.68, + -113.939995, -82.979996, -98.160004, -113.339996, -82.539993, -97.639999, -112.739998, -82.099998, -97.120003, -112.139999, -81.660004, -96.600006, -111.539993, + -81.220001, -96.080002, -110.939995, -80.779999, -95.559998, -110.340012, -80.340004, -95.040001, -109.740005, -79.900002, -94.519997, -109.139992, -77.699997, + -91.919998, -106.139999, -77.260002, -91.400002, -105.540001, -76.820007, -90.880005, -104.940002, -76.380005, -90.360001, -104.339996, -75.940002, -89.839996, -103.740005, -75.5, -89.320007, -103.139999, -75.060005, -88.800003, -102.540001, -74.619995, 
-88.279999, -101.940002, -74.18, -87.759995, -101.339996, -73.740005, -87.239998, -100.739998, -73.300003, -86.720001, -100.139999, -72.860001, -86.199997, -99.539993, -72.419998, -85.68, -98.939995, -71.979996, -85.160004, -98.339996, -71.539993, -84.639999, -97.740005, -71.099998, -84.120003, -97.139999, -68.899994, -81.519997, -94.139999, -68.459999, -81.00, -93.539993, -68.019997, -80.479996, -92.940002, -67.580002, -79.959999, -92.339996, -67.139999, -79.440002, -91.740005, -66.699997, -78.919998, -91.139999, -66.260002, -78.399994, -90.540001, -65.820007, -77.880005, -89.940002, -65.380005, -77.360001, -89.339996, -64.940002, -76.839996, -88.740005, -64.5, -76.320007, -88.139999, -64.060005, -75.800003, -87.540001, -63.619995, -75.279999, -86.940002, -63.18, -74.759995, -86.339996, -62.739998, -74.239998, -85.739998, -62.299999, -73.720001, -85.139999}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, kH, kW, oC}, {-592.800110, -593.039917, -594.719116, -594.960266, -427.199890, -427.919617, -432.959900, -433.679993, -261.600281, -262.799591, -271.200317, -272.399536}, sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-204.600006, -204.}, sd::DataType::FLOAT32); + + input.linspace(70., -1); + gradO.linspace(-4, 0.01); + + sd::ops::deconv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_1) { + auto input = NDArrayFactory::create('c', {2, 2, 6}); + auto weights = NDArrayFactory::create('c', {2, 2, 3}, 
{1,5,9,3,7,11,2,6,10,4,8,12}); + auto bias = NDArrayFactory::create('c', {3}); + auto expFF = NDArrayFactory::create('c', {2, 3, 5}, {59.0f, 69.0f, 79.0f, 89.0f, 99.0f, 132.0f, 158.0f, 184.0f, 210.0f, 236.0f, 205.0f, 247.0f, 289.0f, 331.0f, 373.0f, 179.0f, 189.0f, 199.0f, 209.0f, 219.0f, 444.0f, 470.0f, 496.0f, 522.0f, 548.0f, 709.0f, 751.0f, 793.0f, 835.0f, 877.0f}); + auto expEps = NDArrayFactory::create('c', {2, 2, 6}, {130.0f, 293.0f, 326.0f, 359.0f, 392.0f, 220.0f, 166.0f, 371.0f, 416.0f, 461.0f, 506.0f, 280.0f, 355.0f, 788.0f, 821.0f, 854.0f, 887.0f, 490.0f, 481.0f, 1046.0f, 1091.0f, 1136.0f, 1181.0f, 640.0f}); + auto expGW = NDArrayFactory::create('c', {3, 2, 2}, {1415.0f, 1520.0f, 2045.0f, 2150.0f, 1865.0f, 2020.0f, 2795.0f, 2950.0f, 2315.0f, 2520.0f, 3545.0f, 3750.0f}); + auto expGB = NDArrayFactory::create('c', {3}, {105.0f, 155.0f, 205.0f}); + + expGW.permutei({2,1,0}); + input.linspace(1); + bias.linspace(1); + + sd::ops::conv1d op; + auto result_FF = op.evaluate({&input, &weights, &bias}, {}, {2, 1, 0, 1, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result_FF.status()); + + auto z = result_FF.at(0); + + ASSERT_TRUE(expFF.isSameShape(z)); + ASSERT_TRUE(expFF.equalsTo(z)); + + sd::ops::conv1d_bp op_bp; + + auto epsilonNxt = new NDArray(z->dup()); + epsilonNxt->linspace(1); + + auto result_BP = op_bp.evaluate({&input, &weights, &bias, epsilonNxt}, {}, {2, 1, 0, 1, 0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result_BP.status()); + + auto eps = result_BP.at(0); + auto gradW = result_BP.at(1); + auto gradB = result_BP.at(2); + + ASSERT_TRUE(expEps.isSameShape(eps)); + ASSERT_TRUE(expGW.isSameShape(gradW)); + ASSERT_TRUE(expGB.isSameShape(gradB)); + + ASSERT_TRUE(expEps.equalsTo(eps)); + ASSERT_TRUE(expGW.equalsTo(gradW)); + ASSERT_TRUE(expGB.equalsTo(gradB)); + + delete epsilonNxt; +} + + +TYPED_TEST(TypedConvolutionTests1, Test_Conv1D_ff_2) { + auto input = NDArrayFactory::create('c', {2, 2, 6}); + auto weights = NDArrayFactory::create('c', {2, 2, 3}, {1.f, 5.f, 9.f, 3.f, 
7.f, 11.f, 2.f, 6.f, 10.f, 4.f, 8.f, 12.f}); + + input.linspace(1); + + sd::ops::conv1d op; + auto result = op.evaluate({&input, &weights}, {}, {2, 1, 0, 1, 1,0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_1) { + + int bS=2, iW=3, iC=4,oC=3, kW=2, sW=1, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3}); + + NDArray expOutput('c', {bS, oW, oC}, {18. , 18. , 18. , 53. , 55.6, 58.2, 89.8, 95.6, 101.4, 102. , 106.8, 111.6, 163.4, 175.6, 187.8, 200.2, 215.6, 231.}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_2) { + + int bS=2, iW=16, iC=3,oC=4, kW=2, sW=2, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3,-4}); + + NDArray expOutput('c', {bS, oW, oC}, { 10. , 9.6, 9.2, 8.8, 48.9, 51.8, 54.7, 57.6, 88.5, 95. , 101.5, 108. , 128.1, 138.2, 148.3, 158.4, + 167.7, 181.4, 195.1, 208.8, 207.3, 224.6, 241.9, 259.2, 246.9, 267.8, 288.7, 309.6, 286.5, 311. , 335.5, 360. , + 254.8, 268.8, 282.8, 296.8, 365.7, 397.4, 429.1, 460.8, 405.3, 440.6, 475.9, 511.2, 444.9, 483.8, 522.7, 561.6, + 484.5, 527. , 569.5, 612. 
, 524.1, 570.2, 616.3, 662.4, 563.7, 613.4, 663.1, 712.8, 603.3, 656.6, 709.9, 763.2}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_3) { + + int bS=2, iW=16, iC=3,oC=4, kW=3, sW=3, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3,-4}); + + NDArray expOutput('c', {bS, oW, oC}, {17.2, 16.8, 16.4, 16.,145.4, 151.6, 157.8, 164.,283.1, 297.4, 311.7, 326., 420.8, 443.2, 465.6, 488., + 558.5, 589., 619.5, 650.,696.2001, 734.8, 773.4, 812., 434.8, 448.8, 462.8, 476.8, 879.8, 929.2, 978.6, 1028., + 1017.5, 1075., 1132.5, 1190.,1155.2001, 1220.8, 1286.4, 1352.,1292.8999, 1366.6, 1440.3, 1514., 1430.6001, 1512.4, 1594.2, 1676.}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_4) { + + int bS=2, iW=8, iC=3,oC=4, kW=3, sW=1, pW=0, dW=3; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3,-4}); + + NDArray 
expOutput('c', {bS, oW, oC}, {17.2, 16.8, 16.4, 16. ,43.3, 43.8, 44.3, 44.8,69.4, 70.8, 72.2, 73.6,106.5, 109.4, 112.3, 115.2,147.9, 152.6, 157.3, 162. ,189.3, 195.8, 202.3, + 208.8,234.5, 243.4, 252.3, 261.2,280.4, 292. , 303.6, 315.2, 226. , 232.8, 239.6, 246.4, 252.1, 259.8, 267.5, 275.2,278.2, 286.8, 295.4, 304. ,437.7, + 455. , 472.3, 489.6,479.1, 498.2, 517.3, 536.4,520.5, 541.4, 562.3, 583.2, 601.7, 632.2, 662.7, 693.2, 647.6, 680.8, 714. , 747.2}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_5) { + + int bS=2, iW=8, iC=3,oC=4, kW=3, sW=1, pW=0, dW=3; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iW}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3,-4}); + + NDArray expOutput('c', {bS, oC, oW}, { 83.7, 92.4, 101.1, 162.1, 175.9, 189.7, 223.4, 238.7,85.4, 94.4, 103.4, 167.4, 181.8, 196.2, 233.2, 249.4,87.1, 96.4, 105.7, 172.7, 187.7, 202.7, 243. , 260.1, + 88.8, 98.4, 108. , 178. , 193.6, 209.2, 252.8, 270.8, 292.5, 301.2, 309.9, 493.3, 507.1, 520.9, 590.6, 605.9, 301.4, 310.4, 319.4, 513. , 527.4, 541.8, 622. , 638.2, + 310.3, 319.6, 328.9, 532.7, 547.7, 562.7, 653.4, 670.5, 319.2, 328.8, 338.4, 552.4, 568. 
, 583.6, 684.8, 702.8}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_6) { + + int bS=2, iW=16, iC=3,oC=4, kW=3, sW=3, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iW}); + NDArray weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3,-4}); + + NDArray expOutput('c', {bS, oC, oW}, {159.7,335.3,381.2,427.1,473. ,518.9,163.8,351.4,400. ,448.6,497.2,545.8,167.9,367.5,418.8,470.1,521.4,572.7,172. ,383.6,437.6,491.6,545.6,599.6, + 577.3, 1069.7, 1115.6, 1161.5, 1207.4, 1253.3,595.8, 1129. 
, 1177.6, 1226.2, 1274.8, 1323.4,614.3, 1188.3, 1239.6, 1290.9, 1342.2, 1393.5, + 632.8, 1247.6, 1301.6, 1355.6, 1409.6, 1463.6}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights, &bias}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_7) { + + int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kW, iC, oC}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, 34.299999, 36.400002, 49.899998, 53.800003, 57.699997, + 61.599998, 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, 129.100006, + 140.199997, 151.300003, 162.399994, 148.899994, 161.800003, 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, 188.500000, 205.000000, + 221.500000, 238.000000, 208.299988, 226.600006, 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, 247.899994, 269.799988, 291.700012, + 313.600006, 267.700012, 291.399994, 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, 307.299988, 334.600006, 361.899994, 389.200012}, sd::DataType::FLOAT32); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + 
ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_8) { + + int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=2; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kW, iC, oC}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 26.299999, 27.799999, 29.299999, 30.799999, 45.399998, 48.399998, + 51.400002, 54.400005, 65.199997, 70.000000, 74.800003, 79.600006, 85.000000, 91.600006, 98.199997, 104.800003, 104.799995, 113.199997, 121.600006, + 130.000000, 124.599998, 134.800003, 145.000000, 155.200012, 144.399994, 156.399994, 168.399994, 180.400009, 133.400009, 141.199997, 149.000000, + 156.800003, 148.699997, 157.400009, 166.099991, 174.800003, 203.800003, 221.200012, 238.599991, 256.000000, 223.599991, 242.799988, 262.000000, + 281.200012, 243.399994, 264.399994, 285.399994, 306.399994, 263.199982, 286.000000, 308.799988, 331.600006, 283.000000, 307.600006, 332.200012, + 356.800018, 302.799988, 329.199982, 355.600006, 382.000000}, sd::DataType::FLOAT32); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + sd::ops::conv1d op; + auto results = op.evaluate({&input, &weights}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + + + + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { + + int bS=2, iW=3, iC=4,oC=3, kW=2, sW=1, pW=0, dW=1; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}); + NDArray 
weights('c', {kW, iC, oC}); + NDArray bias('c', {oC}, {-1,-2,-3}); + NDArray gradO('c', {bS, oW, oC}); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + gradO.linspace(-1.5, 0.1); + + const OpArgsHolder argsHolderFF({&input, &weights, &bias}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + const OpArgsHolder argsHolderBP({&input, &weights, &bias, &gradO}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + + sd::ops::conv1d opFF; + sd::ops::conv1d_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +TEST_F(ConvolutionTests1, Test_Dilation2D_1) { + auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); + auto weights = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3, 3, 3}, {77, 79, 81, 83, 85, 87, 80, 82, 84, 113, 115, 117, 119, 121, 123, 116, 118, 120, 107, 109, 111, 113, 115, 117, 110, 112, 114, 185, 187, 189, 191, 193, 195, 188, 190, 192, 221, 223, 225, 227, 229, 231, 224, 226, 228, 215, 217, 219, 221, 223, 225, 218, 220, 222,}); + + input.linspace(1); + weights.linspace(1); + + sd::ops::dilation2d op; + auto result = op.evaluate({&input, &weights}, {1, 1,2,2,1, 1,2,2,1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(ConvolutionTests1, Test_Dilation2D_2) { + auto input = NDArrayFactory::create('c', {2, 6, 6, 3}); + auto weights = NDArrayFactory::create('c', {3, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 1, 2, 3}, {95, 97, 99, 101, 103, 105, 203, 205, 207, 209, 211, 213}); + + input.linspace(1); + weights.linspace(1); + + sd::ops::dilation2d op; + auto result = op.evaluate({&input, &weights}, {0, 1,2,2,1, 1,2,2,1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + 
+////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test1) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{ 0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, + 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f,11.37f, 12.693f, 14.016f, 15.339f, 5.266f, 5.707f, 6.148f, 6.589f,12.98f, 13.916f, 14.852f, 15.788f,14.564f, 15.608f, 16.652f, 17.696f, + 3.25f, 4.015f, 4.78f, 5.545f, 9.812f, 11.396f, 12.98f, 14.564f,10.532f, 12.224f, 13.916f, 15.608f, 9.708f, 10.977f, 12.246f, 13.515f,25.194f, 27.813f, 30.432f, 33.051f,26.922f, 29.703f, 32.484f, 35.265f, + 11.814f, 13.326f, 14.838f, 16.35f,30.378f, 33.483f, 36.588f, 39.693f,32.106f, 35.373f, 38.64f, 41.907f,13.474f, 14.563f, 15.652f, 16.741f,31.988f, 34.22f, 36.452f, 38.684f,33.572f, 35.912f, 38.252f, 40.592f}); + + auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, oC},{14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f,14.4f, 14.76f, 15.12f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, 9.24f, 9.48f, 9.72f, + 17.04f, 17.52f, 18.f,17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f, 17.04f, 17.52f, 18.f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f,10.88f, 11.2f, 11.52f, + 11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f,11.16f, 11.52f, 11.88f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 7.56f, 7.08f, 7.32f, 
7.56f, 7.08f, 7.32f, 7.56f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test2) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f,0.118f,0.181f, 0.244f, 0.307f,0.212f,0.257f, 0.302f, 0.347f,0.208f,0.298f, 0.388f, 0.478f,1.028f,1.262f, 1.496f, 1.73f,1.036f,1.18f, 1.324f, 1.468f, + 0.928f,1.018f, 1.108f, 1.198f,2.9f,3.134f, 3.368f, 3.602f,2.188f,2.332f, 2.476f, 2.62f, 1.202f,1.274f, 1.346f, 1.418f,3.142f,3.313f, 3.484f, 3.655f,2.048f,2.147f, 2.246f, 2.345f, + 0.086f,0.212f, 0.338f, 0.464f,0.694f,0.973f, 1.252f, 1.531f,0.716f,0.869f, 1.022f, 1.175f,1.216f,1.522f, 1.828f, 2.134f,3.908f,4.574f, 5.24f, 5.906f,2.908f,3.268f, 3.628f, 3.988f, + 3.664f,3.97f, 4.276f, 4.582f,9.236f,9.902f,10.568f,11.234f,5.788f,6.148f, 6.508f, 6.868f,3.002f,3.182f, 3.362f, 3.542f,7.174f,7.561f, 7.948f, 8.335f,4.28f,4.487f, 4.694f, 4.901f}); + + auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, oC},{1.84f, 
2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f, + 1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f,1.84f, 2.f, 2.16f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) { + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + + auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f, + 0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f, + 1.467f, 3.06f, 1.596f, 
3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f, + 2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f}); + + auto expGradW = NDArrayFactory::create('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, + 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, + 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, + 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f}); + auto expGradB = NDArrayFactory::create('c', {oC},{0.68f, 1.f, 1.32f}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + weights.permutei({2,3,1,0}); + expGradW.permutei({2,3,1,0}); + + sd::ops::conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + 
ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_4) { + + int bS=1, iH=7,iW=1, iC=2,oC=3, kH=2,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=7,oW=1; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1,2,3}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray gradW('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); + NDArray gradB('c', {oC}, sd::DataType::FLOAT32); + + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto status = op.execute({&input, &weights, &bias, &gradO}, {&gradI, &gradW, &gradB}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); + + ASSERT_EQ(Status::OK(), status); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_5) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, iC, kH, kW}, {3.6, 2.4, 1.2, 0.0, -1.2, -2.4, 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, 3.0, 1.8, 0.6, -0.6, -1.8, -3.0, 2.7, 1.5, 0.3, -0.9, -2.1, -3.3, 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, 3.2, 2.0, 0.8, -0.4, -1.6, -2.8, 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, 2.6, 1.4, 0.2, -1.0, -2.2, -3.4, 3.4, 2.2, 1.0, -0.2, -1.4, -2.6, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, 2.8, 1.6, 0.4, -0.8, -2.0, -3.2, 
2.5, 1.3, 0.1, -1.1, -2.3, -3.5}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iH, iW},{0.517, 0.959, 0.406, 0.884, 1.474, 0.518, 0.020, -0.398, -0.490, -0.281, -0.853, -0.608, 0.472, 0.860, 0.352, 0.776, 1.240, + 0.392, -0.088, -0.632, -0.616, -0.344, -0.988, -0.680, 0.427, 0.761, 0.298, 0.668, 1.006, 0.266, -0.196, -0.866, -0.742, -0.407, -1.123, -0.752, 0.382, 0.662, + 0.244, 0.560, 0.772, 0.140, -0.304, -1.100, -0.868, -0.470, -1.258, -0.824, 1.777, 3.047, 1.234, 2.540, 3.922, 1.310, -0.052, -1.406, -1.426, -0.749, -2.221, + -1.508, 1.624, 2.732, 1.072, 2.216, 3.256, 0.968, -0.376, -2.072, -1.768, -0.920, -2.572, -1.688, 1.471, 2.417, 0.910, 1.892, 2.590, 0.626, -0.700, -2.738, -2.110, + -1.091, -2.923, -1.868, 1.318, 2.102, 0.748, 1.568, 1.924, 0.284, -1.024, -3.404, -2.452, -1.262, -3.274, -2.048}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {oC, iC, kH, kW},{-3.3, -2.62, -1.26, -0.58, 0.78, 1.46, 4.86, 5.54, 6.9, 7.58, 8.940001, 9.619999, 13.02, 13.700001, 15.06, 15.74, 17.1, + 17.780001, 21.18, 21.860001, 23.219999, 23.900002, 25.259998, 25.940001, -10.340001, -9.34, -7.339999, -6.34, -4.339999, -3.339999, 1.66, 2.66, 4.660001, + 5.660001, 7.66, 8.66, 13.66, 14.660001, 16.66, 17.66, 19.66, 20.66, 25.66, 26.66, 28.66, 29.66, 31.66, 32.66, -17.380001, -16.059999, -13.420003, -12.099999, + -9.46, -8.139999, -1.540001, -0.219999, 2.419999, 3.739999, 6.379999, 7.7, 14.299999, 15.62, 18.26, 19.58, 22.219999, 23.539999, 30.139999, 31.459999, 34.099998, + 35.419998, 38.060001, 39.380001}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {0.68, 1., 1.32}, sd::DataType::FLOAT32); + + input.linspace(-48, 1); + // weights.linspace(3.6, -0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 
dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_6) { + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kH, kW, iC, oC], 1-[oC, iC, kH, kW], 2-[oC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, kH, kW, iC}, {3.6, 0.0, 3.3, -0.3, 3.0, -0.6, 2.7, -0.9, 3.5, -0.1, 3.2, -0.4, 2.9, -0.7, 2.6, -1.0, 3.4, -0.2, 3.1, -0.5, 2.8, -0.8, 2.5, -1.1, 2.4, -1.2, 2.1, -1.5, 1.8, -1.8, 1.5, -2.1, 2.3, -1.3, 2.0, -1.6, 1.7, -1.9, 1.4, -2.2, 2.2, -1.4, 1.9, -1.7, 1.6, -2.0, 1.3, -2.3, 1.2, -2.4, 0.9, -2.7, 0.6, -3.0, 0.3, -3.3, 1.1, -2.5, 0.8, -2.8, 0.5, -3.1, 0.2, -3.4, 1.0, -2.6, 0.7, -2.9, 0.4, -3.2, 0.1, -3.5}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iH, iW, iC}, {0.882, -0.522, 0.765, -0.639, 1.953, -1.503, 1.665, -1.791, 2.691, -2.061, 2.295, -2.457, 2.259, -1.305, 1.962, -1.602, 4.545, + -3.555, 3.870, -4.230, 5.625, -4.419, 4.788, -5.256001, 4.122, -2.358, 3.582, -2.898, 7.785, -6.147, 6.624, -7.308, 8.865, -7.011, 7.541999, -8.334, 3.273, -2.019, + 2.832, -2.460, 6.069, -5.163, 5.133, -6.099, 6.771, -5.757, 5.727, -6.801, 5.958, -3.222, 5.193, -3.987, 10.809, -8.198999, 9.225, -9.783, 11.547, -8.757, 9.855, + -10.448999, 9.711, -5.517, 8.441999, 
-6.786, 17.505001, -13.922999, 14.886, -16.542, 18.585001, -14.787001, 15.804001, -17.568001, 11.574, -6.570, 10.062, -8.082, + 20.745001, -16.514999, 17.639999, -19.619999, 21.825001, -17.379002, 18.558001, -20.646, 8.133, -4.935, 7.044, -6.024, 14.492998, -12.291, 12.261, -14.523001, 15.195001, -12.885, 12.855, -15.225}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {oC, kH, kW, iC},{34.559998, 41.760010, 48.959999, 56.160004, 33.119999, 37.739998, 42.360001, 46.979996, 120.960007, 129.480011, 138.0, 146.519989, + 91.200005, 96.639999, 102.079994, 107.520004, 114.479996, 120.059998, 125.639999, 131.220001, 82.080002, 85.620003, 89.160004, 92.699997, 33.120003, 40.499996, + 47.879993, 55.260002, 32.399998, 37.139996, 41.880001, 46.620003, 120.479988, 129.240005, 137.999985, 146.759995, 91.199997, 96.799995, 102.399994, 108.0, 115.199989, + 120.959999, 126.720001, 132.479996, 82.799995, 86.460007, 90.119995, 93.779999, 31.679998, 39.239994, 46.800003, 54.359997, 31.680000, 36.540001, 41.400002, 46.260002, + 120.0, 129.0, 138.0, 147.0, 91.200005, 96.960007, 102.720001, 108.480003, 115.919998, 121.860001, 127.799988, 133.740005, 83.520004, 87.300003, 91.080002, 94.860001}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {8.520, 8.760, 9.}, sd::DataType::FLOAT32); + + input.linspace(-48, 1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +//////////////////////////////////////////////////////////////////// 
+TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) { + + int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=3,oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f, + 5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f, + 28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f, + 58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f, + 9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 
66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f, + 29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f, + 148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f, + 178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, 390.928f}); + + auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, + 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, + 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, + 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 
69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, + 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, + 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + +} + + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) { + + int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f, + 
1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f, + 6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f, + 8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f, + 0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f, + 4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f, + 28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f, + 20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 31.996f, 32.788f, 
33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f}); + + auto expGradW = NDArrayFactory::create('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, + 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f}); + // auto expGradB('c', {oC},{}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) { + + int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, 
dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + + auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f, + 2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f, + 2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f, + 3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f, + 5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f, + 6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 
7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f, + 7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f, + 9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f}); + + auto expGradW = NDArrayFactory::create('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, + 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f}); + + auto expGradB = NDArrayFactory::create('c', 
{oC},{2.64f, 3.92f, 5.2f}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + weights.permutei({2, 3, 4, 1, 0}); + expGradW.permutei({2, 3, 4, 1, 0}); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* gradI = results.at(0); + auto* gradW = results.at(1); + auto* gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv3d_bp_test4) { + + int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, iC, kD, kH, kW}, {7., 5.8, 4.6, 3.4, 2.2, 1., -0.2, -1.4, -2.6, -3.8, -5., -6.2, 6.7, 5.5, 4.3, 3.1, 1.9, 0.7, -0.5, -1.7, -2.9, -4.1, + -5.3, -6.5, 6.4, 5.2, 4., 2.8, 1.6, 0.4, -0.8, -2., -3.2, -4.4, -5.6, -6.8, 6.1, 4.9, 3.7, 2.5, 1.3, 0.1, -1.1, -2.3, -3.5, -4.7, -5.9, -7.1, 6.9, 5.7, 4.5, + 3.3, 2.1, 0.9, -0.3, -1.5, -2.7, -3.9, -5.1, -6.3, 6.6, 5.4, 4.2, 3., 1.8, 0.6, -0.6, -1.8, -3., -4.2, -5.4, -6.6, 6.3, 5.1, 3.9, 2.7, 1.5, 0.3, -0.9, -2.1, + -3.3, -4.5, -5.7, -6.9, 6., 4.8, 3.6, 2.4, 1.2, 0., -1.2, -2.4, -3.6, -4.8, -6., -7.2, 6.8, 5.6, 4.4, 3.2, 2., 0.8, -0.4, -1.6, -2.8, -4., -5.2, -6.4, 6.5, 5.3, 4.1, 2.9, 1.7, 0.5, -0.7, -1.9, -3.1, -4.3, -5.5, -6.7, 6.2, 5., 3.8, 2.6, 1.4, 0.2, -1., -2.2, -3.4, -4.6, -5.8, 
-7., 5.9, 4.7, 3.5, 2.3, 1.1, -0.1, -1.3, -2.5, -3.7, -4.9, -6.1, -7.3}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oD, oH, oW}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iD, iH, iW},{1.847, 3.577, 1.694, 3.460, 6.542, 3.010, 1.469, 2.677, 1.172, 3.226, 5.929999, 2.632, 5.408, 9.483999, 3.932, 1.894, + 2.978, 1.012, 0.058, -0.694, -0.824, -1.504, -4.916, -3.556, -1.850, -4.798, -3.020, -1.069, -2.687, -1.654, -3.236, -7.714, -4.550, -2.311, -5.315, -3.040, + 1.766, 3.406, 1.604, 3.280, 6.164, 2.812, 1.370, 2.470, 1.064, 3.028, 5.516, 2.416, 4.976, 8.584001, 3.464, 1.660, 2.492, 0.760, -0.140, -1.108, -1.040, -1.936, + -5.816, -4.024, -2.084, -5.284, -3.272, -1.186, -2.930, -1.780, -3.488, -8.236, -4.820, -2.446, -5.594, -3.184, 1.685, 3.235, 1.514, 3.100, 5.786, 2.614, 1.271, + 2.263, 0.956, 2.830, 5.102, 2.200, 4.544001, 7.683999, 2.996, 1.426, 2.006, 0.508, -0.338, -1.522, -1.256, -2.368, -6.716, -4.492, -2.318, -5.770, -3.524, -1.303, + -3.173, -1.906, -3.740, -8.757999, -5.090, -2.581, -5.873, -3.328, 1.604, 3.064, 1.424, 2.920, 5.408, 2.416, 1.172, 2.056, 0.848, 2.632, 4.688, 1.984, 4.112, 6.784, 2.528, 1.192, 1.520, 0.256, -0.536, -1.936, -1.472, -2.800, -7.616, -4.960, -2.552, -6.256, -3.776, -1.420, -3.416, -2.032, -3.992, -9.280001, -5.360, -2.716, -6.152, -3.472, 6.815001, 12.649, 5.798, 11.668, 21.230, 9.490, 4.709, 8.292999, 3.548, 9.706, 17.162001, 7.384, 14.912001, 25.036001, 9.980001, 4.918, 7.298, 2.308, -0.374, -3.286, -2.984, -5.824, -17.012001, -11.332001, -5.738, -14.302, -8.636, -3.013, -7.439, -4.462, -8.852, -20.674, -11.894, -5.983, -13.523, -7.576, 6.518, 12.046, 5.492, 11.056, 19.988001, 8.860001, 4.394, 7.654, 3.224, 9.075999, 15.883999, 6.736001, 13.616, 22.407999, 8.648, 4.252, 5.947999, 1.624, -1.004, -4.564, -3.632, -7.120, -19.639999, -12.664001, -6.404, -15.652, -9.320, -3.346, -8.114, -4.804, -9.536, -22.059999, -12.596, -6.334, -14.233999, 
-7.936, 6.221, 11.443, 5.186, 10.444, 18.746, 8.230, 4.079, 7.015, 2.900, 8.446, 14.606001, 6.088, 12.320, 19.779999, 7.316, 3.586, 4.598001, 0.940, -1.634, -5.842, -4.280, -8.416, -22.268002, -13.996, -7.070001, -17.001999, -10.004001, -3.679, -8.789, -5.146, -10.220, -23.445999, -13.298, -6.684999, -14.945, -8.296, 5.924, 10.840, 4.880, 9.832001, 17.504, 7.600, 3.764, 6.376, 2.576, 7.816, 13.328, 5.440001, 11.024, 17.152, 5.983999, 2.920, 3.247999, 0.256, -2.264, -7.120, -4.928, -9.712, -24.896, -15.328, -7.736, -18.352001, -10.688, -4.012, -9.464, -5.488, -10.903999, -24.832001, -14.000, -7.035999, -15.656, -8.655999}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {oC, iC, kD, kH, kW},{-24.399998, -23.080000, -20.440001, -19.119999, -12.519999, -11.199998, -8.560001, -7.240002, -0.639999, 0.679999, + 3.320001, 4.640001, 23.119999, 24.439999, 27.080002, 28.400002, 35.000000, 36.320000, 38.959999, 40.279999, 46.879997, 48.200005, 50.839996, 52.160004, + 70.639999, 71.959999, 74.599998, 75.919998, 82.520004, 83.840004, 86.479996, 87.800003, 94.399994, 95.719994, 98.360001, 99.680008, 118.160004, 119.479996, + 122.120003, 123.440010, 130.040009, 131.360001, 134.000000, 135.319992, 141.919998, 143.239990, 145.879990, 147.200012, -70.159996, -68.200005, -64.279999, + -62.319996, -52.519993, -50.559994, -46.640003, -44.680000, -34.880001, -32.919998, -29.000002, -27.040005, 0.400004, 2.359996, 6.279998, 8.240004, 18.040001, + 20.000000, 23.920002, 25.879999, 35.680000, 37.639996, 41.560001, 43.520000, 70.959999, 72.919998, 76.840004, 78.799995, 88.599998, 90.560005, 94.479996, 96.440002, 106.240005, 108.199997, 112.120003, 114.080002, 141.519989, 143.479996, 147.400009, 149.360001, 159.159988, 161.119995, 165.040009, 167.000000, 176.800003, 178.760010, 182.679993, 184.639999, -115.920006, -113.320000, -108.120003, -105.520012, -92.520004, -89.919991, -84.720001, -82.119995, -69.120010, -66.520004, -61.320000, -58.719994, -22.320000, -19.719999, -14.520001, 
-11.920001, 1.079997, 3.679997, 8.879997, 11.480003, 24.480001, 27.079998, 32.280003, 34.880001, 71.279999, 73.880005, 79.080002, 81.680000, 94.679993, 97.280006, 102.479996, 105.080002, 118.080002, 120.679993, 125.879997, 128.479996, 164.880005, 167.479996, 172.679993, 175.279999, 188.279984, 190.880005, 196.080002, 198.679993, 211.680008, 214.280014, 219.479996, 222.079987}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {2.64, 3.92, 5.2}, sd::DataType::FLOAT32); + + input.linspace(-75, 0.5); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv3d_bp_test5) { + + int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=4,oH=3,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, kD, kH, kW, iC}, {15., 14.7, 14.4, 14.1, 13.8, 13.5, 13.2, 12.9, 12.6, 12.3, 12., 11.7, 11.4, 11.1, 10.8, 10.5, 10.2, 9.9, 9.6, 9.3, 9., + 8.7, 8.4, 8.1, 7.8, 7.5, 7.2, 6.9, 6.6, 6.3, 6., 5.7, 5.4, 5.1, 4.8, 4.5, 4.2, 3.9, 3.6, 3.3, 3., 2.7, 2.4, 2.1, 1.8, 1.5, 1.2, 0.9, 14.9, 14.6, 14.3, 14., + 13.7, 13.4, 13.1, 12.8, 12.5, 12.2, 11.9, 11.6, 11.3, 11., 10.7, 
10.4, 10.1, 9.8, 9.5, 9.2, 8.9, 8.6, 8.3, 8., 7.7, 7.4, 7.1, 6.8, 6.5, 6.2, 5.9, 5.6, 5.3, 5., + 4.7, 4.4, 4.1, 3.8, 3.5, 3.2, 2.9, 2.6, 2.3, 2., 1.7, 1.4, 1.1, 0.8, 14.8, 14.5, 14.2, 13.9, 13.6, 13.3, 13., 12.7, 12.4, 12.1, 11.8, 11.5, 11.2, 10.9, 10.6, + 10.3, 10., 9.7, 9.4, 9.1, 8.8, 8.5, 8.2, 7.9, 7.6, 7.3, 7., 6.7, 6.4, 6.1, 5.8, 5.5, 5.2, 4.9, 4.6, 4.3, 4., 3.7, 3.4, 3.1, 2.8, 2.5, 2.2, 1.9, 1.6, 1.3, 1., 0.7}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1,-0.5, 0.1}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iD, iH, iW, iC}, {13.565001, 13.286001, 13.007000, 12.728001, 28.264000, 27.652000, 27.040001, 26.427999, 32.547997, 31.827999, 31.108002, + 30.388000, 31.647999, 30.927998, 30.208000, 29.487999, 64.484001, 62.935997, 61.387997, 59.839996, 72.188004, 70.424004, 68.660004, 66.896004, 43.852001, 42.807999, + 41.764000, 40.719997, 87.596001, 85.400002, 83.204002, 81.007996, 95.299988, 92.887993, 90.475998, 88.063995, 34.130997, 33.348000, 32.564999, 31.782001, 67.856995, + 66.210007, 64.563004, 62.916000, 72.987000, 71.178001, 69.369003, 67.559998, 70.179001, 68.369995, 66.561005, 64.751999, 137.927994, 134.147995, 130.367996, 126.587997, + 146.891998, 142.787994, 138.683990, 134.580017, 84.597000, 82.302002, 80.007004, 77.711998, 164.820007, 160.067993, 155.316010, 150.563995, 173.783997, 168.707993, + 163.631989, 158.556000, 58.674000, 57.162003, 55.649994, 54.138000, 114.027008, 110.921997, 107.816994, 104.711990, 119.156998, 115.889999, 112.623001, 109.355995, 113.433006, 110.166000, 106.899002, 103.632004, 218.603989, 211.908020, 205.211975, 198.515991, 227.568008, 220.547974, 213.528015, 206.507996, 127.850998, 124.098000, 120.345001, 116.591995, 245.496002, 237.828018, 230.159988, 222.492004, 254.459991, 246.468002, 238.475998, 230.483994, 34.049000, 32.797997, 31.547001, 30.295998, 64.479996, 61.924000, 59.368004, 56.812000, 67.035995, 64.372002, 61.707996, 
59.044003, 62.248001, 59.584003, 56.919998, 54.256001, 116.180000, 110.744003, 105.307999, 99.872002, 120.428001, 114.776001, 109.124001, 103.472000, 69.268005, 66.279999, 63.292000, 60.304001, 128.923996, 122.839996, 116.755997, 110.671997, 133.171997, 126.872002, 120.571991, 114.271996, 94.565002, 92.342010, 90.118996, 87.896004, 182.488007, 177.988007, 173.488007, 168.988007, 186.772003, 182.164001, 177.556000, 172.947998, 178.095993, 173.488007, 168.880005, 164.272003, 341.828003, 332.504028, 323.180023, 313.856018, 349.532013, 339.992004, 330.451996, 320.911987, 190.299988, 185.368011, 180.436005, 175.503998, 364.940002, 354.967987, 344.996002, 335.024017, 372.644012, 362.455994, 352.268005, 342.080017, 132.303009, 128.604004, 124.904999, 121.206001, 252.536987, 245.057999, 237.578979, 230.100006, 257.666992, 250.026001, 242.385010, 234.744019, 243.195007, 235.554001, 227.912994, 220.272003, 460.631958, 445.188019, 429.744019, 414.299988, 469.595947, 453.827972, 438.059998, 422.291992, 257.613007, 249.486008, 241.358994, 233.232010, 487.523987, 471.108032, 454.691986, 438.276001, 496.488037, 479.748016, 463.007996, 446.268005, 156.846008, 152.417999, 147.989990, 143.561996, 298.707001, 289.769989, 280.833008, 271.895996, 303.837006, 294.737976, 285.638977, 276.540009, 286.449005, 277.350006, 268.250977, 259.151978, 541.307983, 522.947998, 504.587982, 486.227997, 550.271973, 531.588013, 512.903992, 494.220032, 300.867004, 291.281982, 281.696991, 272.112000, 568.200012, 548.868042, 529.535950, 510.204010, 577.164062, 557.507935, 537.851990, 518.196045, 83.944992, 80.750000, 77.555000, 74.360001, 156.496002, 150.052002, 143.608002, 137.164001, 159.052002, 152.500000, 145.947998, 139.395996, 146.488007, 139.936005, 133.384003, 126.832001, 269.107971, 255.895996, 242.684006, 229.471985, 273.356018, 259.927979, 246.500000, 233.071991, 153.507996, 146.632004, 139.755997, 132.880005, 281.851990, 267.992004, 254.132004, 240.272003, 286.100006, 272.023987, 257.947998, 
243.872009}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {oC, kD, kH, kW, iC}, {396.899872, 429.570007, 462.240234, 494.910156, 313.739960, 335.250000, 356.760071, 378.270020, 403.379944, 424.350006, + 445.320007, 466.289978, 299.520020, 313.319977, 327.119995, 340.920013, 1556.280029, 1594.979980, 1633.679932, 1672.379883, 1090.080078, 1115.520020, 1140.959961, + 1166.400024, 1183.679932, 1208.400024, 1233.119995, 1257.840088, 821.279907, 837.519897, 853.760010, 870.000000, 1500.119873, 1525.500122, 1550.880005, 1576.260010, + 1029.780029, 1046.429932, 1063.080078, 1079.729980, 1080.539917, 1096.650024, 1112.760010, 1128.869995, 738.000000, 748.560059, 759.119995, 769.679993, 389.880005, + 422.819946, 455.759979, 488.699951, 309.420013, 331.109985, 352.799988, 374.490051, 399.780029, 420.930023, 442.080017, 463.230011, 297.359985, 311.280029, 325.200012, 339.120056, 1553.400146, 1592.459961, 1631.520020, 1670.579956, 1088.640015, 1114.320068, 1140.000000, 1165.679932, 1183.199951, 1208.160034, 1233.119995, 1258.079956, 821.280029, 837.680054, 854.079956, 870.479980, 1502.819946, 1528.469971, 1554.119995, 1579.770020, 1031.939941, 1048.770020, 1065.599976, 1082.429932, 1083.420044, 1099.709961, 1116.000000, 1132.290039, 740.159973, 750.840027, 761.519958, 772.199951, 382.859924, 416.070099, 449.279968, 482.489990, 305.099976, 326.970062, 348.840027, 370.709991, 396.179962, 417.510010, 438.839966, 460.169952, 295.200012, 309.239990, 323.279968, 337.320007, 1550.519775, 1589.939941, 1629.359985, 1668.779907, 1087.200073, 1113.119995, 1139.039917, 1164.959961, 1182.719971, 1207.920044, 1233.119995, 1258.320190, 821.279968, 837.840027, 854.400024, 870.959961, 1505.520142, 1531.439819, 1557.359985, 1583.279907, 1034.100098, 1051.110107, 1068.120117, 1085.130005, 1086.299927, 1102.770020, 1119.239990, 1135.710083, 742.319946, 753.119995, 763.919983, 774.720032}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {77.400002, 78.119995, 78.840004}, 
sd::DataType::FLOAT32); + + input.linspace(-75, 0.5); + gradO.linspace(0.01, 0.01); + + sd::ops::conv3dnew_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test1) { + + int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 
31.2f, + 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f, + 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test2) { + + int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, + 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 
+ 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test3) { + + int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2}); + input = 2.; + weights = 0.5; + expected = 48.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test4) { + + int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2}); + + input = 2.; + weights = 0.5; + expected = 49.; + bias = 1.; + + sd::ops::conv3dnew op; 
+ auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test5) { + + int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, + 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, + 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f}); + input = 2.; + weights = 0.5; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test6) { + + int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights 
= NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC},{1.f, 2.f, 3.f}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, + 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, + 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f, + 698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f}); + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test7) { + + int bS=2, iD=3,iH=4,iW=3, iC=4,oC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto expected = NDArrayFactory::create('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, + 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, + 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f}); + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 
0}); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test8) { + auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); + auto y = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); + auto e = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); + + sd::ops::conv3dnew op; + auto result = op.evaluate({&x, &y}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); +} + +TYPED_TEST(TypedConvolutionTests1, conv3d_test9) { + auto x = NDArrayFactory::create('c', {4, 2, 28, 28, 3}); + auto w = NDArrayFactory::create('c', {2, 5, 5, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 1, 7, 10, 4}); + + sd::ops::conv3dnew op; + auto result = op.evaluate({&x, &w}, {}, {2,5,5, 5,4,3, 0,0,0, 1,1,1, 1,1}); + ASSERT_EQ(Status::OK(), result.status()); + + ShapeList shapeList({x.shapeInfo(), w.shapeInfo()}); + ContextPrototype proto; + Context ctx(1); + ctx.getIArguments()->push_back(2); + ctx.getIArguments()->push_back(5); + ctx.getIArguments()->push_back(5); + + ctx.getIArguments()->push_back(5); + ctx.getIArguments()->push_back(4); + ctx.getIArguments()->push_back(3); + + ctx.getIArguments()->push_back(0); + ctx.getIArguments()->push_back(0); + ctx.getIArguments()->push_back(0); + + ctx.getIArguments()->push_back(1); + ctx.getIArguments()->push_back(1); + ctx.getIArguments()->push_back(1); + + ctx.getIArguments()->push_back(0); + ctx.getIArguments()->push_back(1); // previous variant was "ctx.getIArguments()->push_back(0)" and this caused fail + + auto shapes = op.calculateOutputShape(&shapeList, ctx); + 
ASSERT_EQ(1, shapes->size()); + + auto s = shapes->at(0); + + auto z = result.at(0); + // z->printShapeInfo("z shape"); + + ASSERT_TRUE(exp.isSameShape(z)); + + delete shapes; +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test10) { + + int bS=1, iD=2,iH=2,iW=2, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + + input = 2.; + weights = 1.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, conv3d_test11) { + + int bS=5, iD=4,iH=14,iW=14, iC=1,oC=1, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=3,oH=13,oW=13; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto expected = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + + input = 2.; + weights = 1.; + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output->isSameShape(&expected)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv3d_test12) { + + int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; 
+ int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, iC, kD, kH, kW}, {-14.4, -13.2, -12.0, -10.8, -9.6, -8.4, -7.2, -6.0, -4.8, -3.6, -2.4, -1.2, -14.1, -12.9, -11.7, -10.5, -9.3, -8.1, + -6.9, -5.7, -4.5, -3.3, -2.1, -0.9, -13.8, -12.6, -11.4, -10.2, -9.0, -7.8, -6.6, -5.4, -4.2, -3.0, -1.8, -0.6, -13.5, -12.3, -11.1, -9.9, -8.7, -7.5, -6.3, + -5.1, -3.9, -2.7, -1.5, -0.3, -14.3, -13.1, -11.9, -10.7, -9.5, -8.3, -7.1, -5.9, -4.7, -3.5, -2.3, -1.1, -14.0, -12.8, -11.6, -10.4, -9.2, -8.0, -6.8, -5.6, + -4.4, -3.2, -2.0, -0.8, -13.7, -12.5, -11.3, -10.1, -8.9, -7.7, -6.5, -5.3, -4.1, -2.9, -1.7, -0.5, -13.4, -12.2, -11.0, -9.8, -8.6, -7.4, -6.2, -5.0, -3.8, -2.6, -1.4, -0.2, -14.2, -13.0, -11.8, -10.6, -9.4, -8.2, -7.0, -5.8, -4.6, -3.4, -2.2, -1.0, -13.9, -12.7, -11.5, -10.3, -9.1, -7.9, -6.7, -5.5, -4.3, -3.1, -1.9, -0.7, -13.6, -12.4, -11.2, -10.0, -8.8, -7.6, -6.4, -5.2, -4.0, -2.8, -1.6, -0.4, -13.3, -12.1, -10.9, -9.7, -8.5, -7.3, -6.1, -4.9, -3.7, -2.5, -1.3, -0.1}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oC, oD, oH, oW}, {-42520.597656, -42344.199219, -41991.402344, -41814.996094, -40932.992188, -40756.597656, -40403.800781, -40227.406250, + -41953.601562, -41779.601562, -41431.597656, -41257.601562, -40387.601562, -40213.597656, -39865.601562, -39691.597656, -41391.105469, -41219.492188, + -40876.300781, -40704.699219, -39846.707031, -39675.097656, -39331.898438, -39160.300781, -17119.001953, -16942.599609, -16589.798828, -16413.400391, + -15531.399414, -15355.000000, -15002.199219, -14825.800781, -16897.597656, -16723.597656, -16375.599609, -16201.599609, -15331.599609, -15157.600586, + -14809.601562, -14635.598633, -16680.703125, -16509.099609, -16165.900391, 
-15994.300781, -15136.300781, -14964.700195, -14621.500000, -14449.900391}, sd::DataType::FLOAT32); + + input.linspace(150,-0.5); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv3d_test13) { + + int bS=2, iD=4,iH=3,iW=3, iC=4,oC=3, kD=3,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=4,oH=3,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kD, kH, kW, iC, oC], 1-[oC, iC, kD, kH, kW], 2-[oC, kD, kH, kW, iC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {oC, kD, kH, kW, iC}, {-7., -6.7, -6.4, -6.1, -5.8, -5.5, -5.2, -4.9, -4.6, -4.3, -4., -3.7, -3.4, -3.1, -2.8, -2.5, -2.2, -1.9, -1.6, -1.3, + -1., -0.7, -0.4, -0.1, 0.2, 0.5, 0.8, 1.1, 1.4, 1.7, 2., 2.3, 2.6, 2.9, 3.2, 3.5, 3.8, 4.1, 4.4, 4.7, 5., 5.3, 5.6, 5.9, 6.2, 6.5, 6.8, 7.1, -6.9, -6.6, -6.3, + -6., -5.7, -5.4, -5.1, -4.8, -4.5, -4.2, -3.9, -3.6, -3.3, -3., -2.7, -2.4, -2.1, -1.8, -1.5, -1.2, -0.9, -0.6, -0.3, 0., 0.3, 0.6, 0.9, 1.2, 1.5, 1.8, 2.1, + 2.4, 2.7, 3., 3.3, 3.6, 3.9, 4.2, 4.5, 4.8, 5.1, 5.4, 5.7, 6., 6.3, 6.6, 6.9, 7.2, -6.8, -6.5, -6.2, -5.9, -5.6, -5.3, -5., -4.7, -4.4, -4.1, -3.8, -3.5, -3.2, + -2.9, -2.6, -2.3, -2., -1.7, -1.4, -1.1, -0.8, -0.5, -0.2, 0.1, 0.4, 0.7, 1., 1.3, 1.6, 1.9, 2.2, 2.5, 2.8, 3.1, 3.4, 3.7, 4., 4.3, 4.6, 4.9, 5.2, 5.5, 5.8, 6.1, 6.4, 6.7, 7., 7.3}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-1,2,0.5}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oD, oH, oW, oC}, {3969.399658, 4168.399902, 4362.899414, 3812.600586, 4005.200195, 4193.299805, 
1317.000000, 1413.199829, 1504.899902, + 3498.999756, 3678.800049, 3854.100098, 3342.200195, 3515.599854, 3684.500244, 1139.400024, 1226.000000, 1308.099976, 685.799927, 772.400024, 854.500000, + 645.800049, 729.200073, 808.099976, 80.799995, 123.200012, 161.100006, -2851.000732, -2597.199707, -2347.899414, -2855.799805, -2611.600098, -2371.900879, + -2124.399414, -2003.199951, -1886.500244, -2865.399902, -2640.400146, -2419.899902, -2870.199951, -2654.800049, -2443.899902, -2045.200073, -1938.399902, + -1836.100220, -2596.000244, -2489.199707, -2386.900146, -2540.799561, -2438.800049, -2341.300049, -1539.699951, -1488.400024, -1441.599854, -2894.200195, + -2726.800049, -2563.899902, -2899.000488, -2741.199707, -2587.899658, -1886.800171, -1808.800049, -1735.300171, -2908.599121, -2770.000488, -2635.900146, -2913.400146, -2784.399658, -2659.899902, -1807.599976, -1743.999878, -1684.900146, -2099.199951, -2035.599976, -1976.500366, -2044.000244, -1985.199707, -1930.900024, -1161.699951, -1132.000122, -1106.800171, -2731.399902, -2647.599609, -2568.300293, -2580.999756, -2503.600098, -2430.699951, -1457.400024, -1418.800049, -1384.700073, -2280.200195, -2215.600098, -2155.500732, -2129.799561, -2071.600098, -2017.899780, -1174.200073, -1145.200195, -1120.699829, -1282.200073, -1253.199951, -1228.699951, -1168.599976, -1142.799927, -1121.500122, -615.199951, -601.600037, -592.500000, -1675.399658, -1706.800049, -1742.700073, -1832.200073, -1870.000000, -1912.299561, -814.199951, -833.200012, -856.699951, -2145.800049, -2196.399902, -2251.500244, -2302.600342, -2359.599854, -2421.100098, -991.800049, -1020.400024, -1053.500000, -754.199951, -782.800049, -815.900085, -794.199951, -825.999939, -862.299988, -293.600006, -308.800018, -328.500000, -3023.800293, -3115.600098, -3211.900391, -3028.599121, -3130.000244, -3235.899902, -1173.999878, -1225.600098, -1281.699951, -3038.200195, -3158.799805, -3283.899902, -3043.000000, -3173.199707, -3307.900391, -1094.800049, 
-1160.800049, -1231.300049, -608.799988, -674.799988, -745.300049, -553.599976, -624.400024, -699.700012, -27.700012, -62.799988, -102.400009, -3066.999512, -3245.199707, -3427.900391, -3071.800293, -3259.599854, -3451.900146, -936.400085, -1031.199951, -1130.500000, -3081.400146, -3288.400635, -3499.899414, -3086.200439, -3302.799805, -3523.899902, -857.199951, -966.400024, -1080.099976, -111.999969, -221.199936, -334.900024, -56.800079, -170.799988, -289.299927, 350.299927, 293.600037, 232.399979, 2683.000244, 2536.400146, 2385.300049, 2833.399658, 2680.400391, 2522.900391, 1940.999878, 1864.399902, 1783.300049, 3134.200195, 2968.399414, 2798.100098, 3284.600098, 3112.400391, 2935.699707, 2224.199707, 2138.000244, 2047.300049, 2807.399658, 2721.200195, 2630.500000, 2921.000000, 2831.599854, 2737.699707, 1775.200195, 1731.199951, 1682.699829}, sd::DataType::FLOAT32); + + input.linspace(75,-0.5); + + sd::ops::conv3dnew op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) { + + int bS=2, iH=4,iW=3, iC=4,oC=3; + + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {1, 1, iC, oC}); + auto bias = NDArrayFactory::create('c', {oC}); + + + auto expOutput = NDArrayFactory::create('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, + 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, + 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 
5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, + 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f}); + input = 2.; + weights.linspace(0.1, 0.1); + bias = 1.; + + sd::ops::pointwise_conv2d op; + auto results = op.evaluate({&input, &weights, &bias}, {}, {dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, vol2col_test1) { + + int bS=2, iD=2,iH=3,iW=2, iC=3,oC=2, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=3,oW=2; + + NDArray volume('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray columns('c', {bS, iC, kD, kH, kW, oD, oH, oW}, sd::DataType::FLOAT32); + + columns = -1.; + volume.linspace(1); + + NDArray columnsExpected('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 2., 0., 4., 0., 6.,0., 8., 0., 10., 0., 12., 0., 3., 4., 5., 6., 0., 0., 9., 10., 11., 12., 0., 0., 4., 0., 6., 0., 0., 0., 10., 0., 12., 0., 0., 0., 5., 6., + 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 6., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 8., 0., 10., 0., 12., 0., 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 10., 0., 12., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17.,18., 19., 20., 21., 22., 23., 24., 14., 0., 16., 0., 18., 0., 20., 0., 22., 0., 24., 0., 15., 16., 17., 18., 0., 0., 21., 22., 23., 24., 0., + 0., 16., 0., 18., 0., 0., 0., 22., 0., 24., 0., 0., 0., 17., 18., 0., 0., 0., 0., 23., 24., 0., 0., 0., 0., 18., 0., 0., 0., 0., 0., 24., 0., 0., 0., 0., 0., 19., 20., 21., 22., 23., 24., 0., 0., 0., 0., 0., 0., 20., 0., 22., 0., 24., 
0., 0., 0., 0., 0., 0., 0., 21., 22., 23., + 24., 0., 0., 0., 0., 0., 0., 0., 0., 22., 0., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 23., 24., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 24.,0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 26., 0., 28., 0., 30., 0., 32., 0., + 34., 0., 36., 0., 27., 28., 29., 30., 0., 0., 33., 34., 35., 36., 0., 0., 28., 0., 30., 0., 0., 0., 34., 0., 36., 0., 0., 0., 29., 30., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 30., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 31., 32., 33., 34., 35., 36., 0., 0., 0., 0., 0., + 0., 32., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 33., 34., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 34., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 35., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 36., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 37., 38., 39., 40., + 41., 42., 43., 44., 45., 46., 47., 48., 38., 0., 40., 0., 42., 0., 44., 0., 46., 0., 48., 0., 39., 40., 41., 42., 0., 0., 45., 46., 47., 48., 0., 0., 40., 0., 42., 0., 0., 0., 46., 0., 48., 0., 0., 0., 41., 42., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 42., 0., 0., 0., 0., + 0., 48., 0., 0., 0., 0., 0., 43., 44., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 44., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 45., 46., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., 46., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 47., 48., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 48., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 50., 0., 52., 0., 54.,0., 56., 0., 58., 0., 60., 0., 51., 52., 53., 54., 0., 0., 57., 58., 59., 60., 0., 0., 52., 0., 54., 0., 0., 0., 58., 0., 60., 0., 0., 0., + 53., 54., 0., 0., 0., 0., 59., 60., 0., 0., 0., 0., 54., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 55., 56., 57., 58., 59., 60., 0., 0.,0., 0., 0., 0., 56., 0., 58., 0., 60., 0., 0., 0., 0., 0., 0., 0., 57., 58., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 58., 0., 60., 0., 
+ 0., 0., 0., 0., 0., 0., 0., 0., 59., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 60., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 62., 0., 64., 0., 66., 0., 68., 0., 70., 0., 72., 0., 63., 64., 65., 66., 0., 0., 69., + 70., 71., 72., 0., 0., 64., 0., 66., 0., 0., 0., 70., 0., 72., 0., 0., 0., 65., 66., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 66., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 68., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., + 0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 70., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 72., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); + + graph::Context context(1); + sd::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); + // columns.printBuffer(); + + ASSERT_TRUE(columns.equalsTo(columnsExpected)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, vol2col_test2) { + + int bS=2, iD=2,iH=3,iW=2, iC=3,oC=2, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=3,oW=2; + + auto volume = NDArrayFactory::create('c', {iD, bS, iH, iC, iW}); + volume.permutei({1, 3, 0, 2, 4}); + volume.linspace(1); + + auto columns = NDArrayFactory::create('c', {kD, iC, kH, oW, kW, bS, oD, oH}); + columns.permutei({5, 1, 0, 2, 4, 6, 7, 3}); + columns = -1.; + auto columnsExpected = NDArrayFactory::create('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, + 
9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, + 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f, + 0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f, + 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 
0.f, 40.f, + 0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f, + 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f, + 0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f, + 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + + graph::Context context(1); + sd::ops::ConvolutionUtils::vol2col(context, volume, columns, sD, sH, sW, pD, pH, pW, dD, dH, dW); + // columns.printBuffer(); + + ASSERT_TRUE(columns.equalsTo(columnsExpected)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, col2im_test1) { + + int bS=2, iH=2,iW=2, iC=2, kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oH=2,oW=2; + + auto image = NDArrayFactory::create('c', {bS, iC, iH, iW}); + image = -2.; + + auto columns = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); + columns.linspace(1); + + auto imageExpected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {1.f, 7.f, 12.f, 34.f, 17.f, 39.f, 44.f, 98.f, 33.f, 71.f, 76.f, 162.f, 49.f, 103.f, 108.f, 226.f}); + + + sd::ops::col2im op; + auto status = op.execute({&columns}, {&image}, {sH, sW, pH, pW, iH, iW, dH, dW, 0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(image.equalsTo(imageExpected)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, upsampling2d_test1) { + + const int bS=3, iH=2,iW=2, iC=3; + const int factorH=2, factorW=3; + const int isNCHW = 0; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 
10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f}); + + sd::ops::upsampling2d op; + auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, upsampling2d_test2) { + + const int bS=3, iH=2,iW=2, iC=3; + const int factorH=2, factorW=3; + const int isNCHW = 1; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, + 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, + 
13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, + 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, + 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, + 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f}); + + sd::ops::upsampling2d op; + auto results = op.evaluate({&input}, {factorH, factorW, isNCHW}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, upsampling3d_test1) { + + const int bS=3, iD=2,iH=2,iW=2, iC=3; + const int factorD=2,factorH=3,factorW=2; + const int isNCDHW = 0; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 
3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, + 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, + 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 
36.f, 34.f, 35.f, 36.f, + 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, + 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, + 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, + 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 
61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f}); + + sd::ops::upsampling3d op; + auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, upsampling3d_test2) { + + const int bS=3, iD=2,iH=2,iW=2, iC=3; + const int factorD=2,factorH=3,factorW=2; + const int isNCDHW = 1; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + input.linspace(1); + + auto expOutput = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, + 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 
12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, + 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, + 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, + 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, + 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, + 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 
38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, + 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, + 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, + 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, + 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 
66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, + 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f}); + + sd::ops::upsampling3d op; + auto results = op.evaluate({&input}, {factorD, factorH, factorW, isNCDHW}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, upsampling3d_bp_test1) { + + const int bS=1, iD=2,iH=2,iW=2, iC=1; + const int factorD=2, factorH=2, factorW=2; + const int isNCDHW = 1; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}); + gradO = 1.; + + auto expGradI = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + expGradI = 8.; + + sd::ops::upsampling3d_bp op; + auto results = op.evaluate({&input, &gradO}, {isNCDHW}); + auto* gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); +} + +TYPED_TEST(TypedConvolutionTests1, conv2D_input_BP_test1) { + + auto inputShape = NDArrayFactory::create('c', {4}, {2, 1, 4, 4}); + auto weights = NDArrayFactory::create('c', {2, 1, 3, 3}); + auto epsilonNext = NDArrayFactory::create('c', {2, 2, 4, 4}); + auto shapeArr = NDArrayFactory::create('c', {2, 1, 4, 4}); + + + TypeParam _expEpsB[] = {952.0, 1540.0, 1636.0, 1180.0, 
1791.0, 2886.0, 3057.0, 2193.0, 2223.0, 3570.0, 3741.0, 2673.0, 1900.0, 3028.0, 3160.0, 2240.0, 2872.0, 4612.0, 4708.0, 3356.0, 5247.0, 8358.0, 8529.0, 6033.0, 5679.0, 9042.0, 9213.0, 6513.0, 4588.0, 7252.0, 7384.0, 5184.0}; + NDArray expEps(_expEpsB, shapeArr.shapeInfo()); + + weights.linspace(1); + epsilonNext.linspace(1); + weights.permutei({2,3,1,0}); + + sd::ops::conv2d_input_bp op; + + auto results = op.evaluate({&inputShape, &weights, &epsilonNext}, {}, {3, 3, 1, 1, 0, 0, 1, 1, 1}); + + ASSERT_TRUE(results.size() == 1); + + auto epsilon = results.at(0); + + ASSERT_TRUE(shapeArr.isSameShape(epsilon)); + ASSERT_TRUE(expEps.equalsTo(epsilon)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, upsampling3d_bp_test3) { + + const int bS=1, iD=3,iH=3,iW=3, iC=2; + const int factorD=2, factorH=2, factorW=2; + const int isNCDHW = 1; // data format, default is NCHW + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338, + 0.44793984, 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668, + 0.13505761, 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439, + 0.32870287, 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 0.7215636, 0.40449402, 0.29908907, 0.4038839, + 0.9883108, 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561, + 0.6994972, 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631, + 0.5277549, 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397, 0.017710684, 0.60847557, 0.52515227, + 0.9171938, 0.84989065, 0.5894228, 0.85227835, 
0.39063585, 0.88968325, 0.6694452, 0.698873, 0.96147966, 0.15740126, 0.15736352, 0.49352047, + 0.5699365, 0.12683152, 0.11572781, 0.7863682, 0.737939, 0.49007934, 0.6084143, 0.9564999, 0.3900982, 0.14730452, 0.8506447, 0.49765033, + 0.07186628, 0.08214969, 0.035314173, 0.7320408, 0.36993408, 0.8406658, 0.27389422, 0.43179566, 0.13323106, 0.19297548, 0.24689731, 0.38641843, + 0.51154125, 0.19903564, 0.1416313, 0.69769853, 0.25363067, 0.78221816, 0.9300991, 0.3355119, 0.5588076, 0.6643576, 0.018850708, 0.63755876, + 0.2904297, 0.43490165, 0.84251267, 0.46609768, 0.38139546, 0.52318525, 0.9901826, 0.9257676, 0.6434591, 0.016828254, 0.9187561, 0.22897908, + 0.0063138064, 0.66597503, 0.19036093, 0.59552056, 0.69888055, 0.22146936, 0.9124342, 0.8708221, 0.7273687, 0.52397245, 0.66288394, 0.2188415, + 0.3354802, 0.03566524, 0.5101009, 0.5017283, 0.75122046, 0.1884508, 0.7407126, 0.6253045, 0.47145858, 0.5369367, 0.19884548, 0.99008304, + 0.08256686, 0.91884845, 0.02360027, 0.98895234, 0.3751719, 0.91783875, 0.4338776, 0.6783008, 0.6667967, 0.46720362, 0.7508773, 0.52304846, + 0.76631916, 0.4187526, 0.7653719, 0.5159193, 0.42730415, 0.49462363, 0.2731735, 0.8862948, 0.043214794, 0.3197591, 0.040378205, 0.5427239, + 0.9228089, 0.045940384, 0.70047987, 0.8419288, 0.53966296, 0.009444186, 0.038044546, 0.03158029, 0.43485752, 0.9204235, 0.5478789, 0.8290083, + 0.11868837, 0.0229866, 0.6639305, 0.8757367, 0.8279557, 0.76270294, 0.43242732, 0.4713431, 0.2569212, 0.30575937, 0.44395888, 0.99384075, + 0.6127142, 0.44844577, 0.6347944, 0.098358564, 0.34233716, 0.9329664, 0.65776783, 0.108565055, 0.2052629, 0.46441218, 0.041791342, 0.89369565, + 0.7000381, 0.2106213, 0.51152664, 0.44200692, 0.8293282, 0.20901772, 0.6387249, 0.8016979, 0.11178707, 0.109545894, 0.19654618, 0.060582615, + 0.08239174, 0.64630795, 0.32862368, 0.60225064, 0.8328141, 0.5484566, 0.8120276, 0.38822946, 0.6742381, 0.34913155, 0.42887798, 0.45344824, + 0.73956585, 0.9714739, 0.42937812, 0.45185348, 0.84535813, 
0.046436775, 0.8802151, 0.8676222, 0.42625394, 0.4985318, 0.42399272, 0.122144565, + 0.0060101906, 0.47253844, 0.18123977, 0.86316174, 0.5863874, 0.3852012, 0.9785553, 0.0054711984, 0.88500834, 0.020897374, 0.27467912, 0.3852802, + 0.0766939, 0.94622654, 0.38687763, 0.3308602, 0.7770494, 0.9052543, 0.22258204, 0.42207044, 0.18050623, 0.21057767, 0.012561422, 0.7977821, + 0.61251044, 0.7203693, 0.6028265, 0.6036933, 0.1446382, 0.6712341, 0.76634467, 0.4854034, 0.26634562, 0.76523924, 0.16348523, 0.2663676, + 0.96846986, 0.8273284, 0.10700377, 0.7600526, 0.6771002, 0.47963092, 0.21264452, 0.56934077, 0.5514792, 0.85725874, 0.99090636, 0.54562527, + 0.93597686, 0.21142527, 0.4628326, 0.35011524, 0.31464386, 0.31164807, 0.65928996, 0.94418925, 0.39666295, 0.9496393, 0.103756346, 0.482158, + 0.49171793, 0.4108867, 0.22594318, 0.97093135, 0.5974685, 0.34632966, 0.54835194, 0.10499302, 0.9767778, 0.55008715, 0.54379046, 0.3583731, + 0.33369112, 0.04279039, 0.24939054, 0.23943715, 0.06775989, 0.7750291, 0.24329625, 0.4327169, 0.86916673, 0.80322117, 0.049972698, 0.47177452, + 0.37419558, 0.15303156, 0.121425234, 0.75884604, 0.8191354, 0.48554084, 0.053899214, 0.7858246, 0.39219773, 0.77579063, 0.34507045, 0.46070176, + 0.14496958, 0.47706795, 0.50678796, 0.64902323, 0.3277943, 0.0017530271, 0.6536156, 0.8582253, 0.95703506, 0.9963951, 0.8239163, 0.305142, + 0.012419582, 0.9498972, 0.1595827, 0.47947606, 0.5071124, 0.78227425, 0.2066719, 0.5217094, 0.7841406, 0.5260441, 0.49798164, 0.10975622, + 0.8633349, 0.76298475, 0.14295428, 0.6131504, 0.43794408, 0.50339264, 0.4504877, 0.19235311, 0.6678411, 0.80769485, 0.67495126, 0.96461457, + 0.10535406, 0.66438645, 0.4372345, 0.93851465, 0.8635335, 0.3405871, 0.45652762, 0.3636232, 0.52931345, 0.20154329, 0.07698499, 0.6125804, + 0.3583082, 0.3894796, 0.32601944, 0.5237369, 0.66683626, 0.08541841, 0.4815708, 0.11897489, 0.97555137, 0.3602705, 0.9620871, 0.6361821, + 0.71167386, 0.5134439, 0.57761437, 0.58598644, 0.39387667, 
0.6966405, 0.46841687, 0.85788506, 0.9957087, 0.051309288, 0.24846801, 0.55938333, + 0.10230542, 0.9370694, 0.57527155, 0.54656035, 0.28896323, 0.51303476, 0.8865, 0.38641605, 0.9836358}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iD, iH, iW}, {3.510932, 3.4310975, 3.538762, 4.148549, 2.8380678, 2.5431657, 3.3928843, 3.228055, 3.1467278, + 3.2603023, 5.611751, 4.334653, 3.3697734, 4.603307, 4.4357986, 4.32991, 3.0532732, 3.1370173, 4.181534, 2.9965065, 2.8553872, 5.2719016, + 4.5671935, 3.7027276, 3.3517184, 5.2544537, 3.5107024, 4.1496124, 3.9333878, 3.1798909, 3.1446428, 3.0932689, 3.9730802, 3.0466917, + 4.9675374, 4.769673, 3.766952, 3.6375027, 3.6492167, 4.9440994, 3.8379507, 3.467589, 4.719474, 3.1295977, 4.5177174, 4.2760015, 2.8443856, + 4.225355, 4.377341, 4.4398847, 4.710785, 4.4199953, 3.928307, 4.8769503}, sd::DataType::FLOAT32); + + sd::ops::upsampling3d_bp op; + auto results = op.evaluate({&input, &gradO}, {isNCDHW}); + auto* gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test1) { + + int bS=2, oH=4,oW=4, oC=5,iC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int iH=3,iW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); + auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 
181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test2) { + + int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 
201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f }); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test3) { + + int bS=1, oH=5,oW=5, oC=3,iC=2, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=2,dW=2; + int iH=3,iW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}); + auto bias = NDArrayFactory::create('c', {oC}); + + auto exp = NDArrayFactory::create('c', {bS, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f, -16.1f, + -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, 
-47.4f, + -32.8f, -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, + -7.4f, -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f}); + + input.linspace(-10, 0.5); + weights.linspace(0.1, 0.1); + bias = 0.2; + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test4) { + + NDArray input('c', {2, 3, 4, 4}, sd::DataType::FLOAT32); + NDArray weights('c', {3, 3, 5, 5}, sd::DataType::FLOAT32); + NDArray exp('c', {2,3,8,8}, {6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0, + 100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0, + 84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0, + 54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0, + 90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0, + 
8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0, + 144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0, + 118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0, + 115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0, + 268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,94701.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0, + 52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0, + 78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0, + 
89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0}, sd::DataType::FLOAT32); + + input.linspace(1); + weights.linspace(1); + weights.permutei({2,3,1,0}); + + sd::ops::deconv2d op; + auto result = op.evaluate({&input, &weights}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + + auto z = result.at(0); + // z->printShapeInfo(); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test5) { + Nd4jLong _expS[] = {4, 2, 3, 8, 8, 192, 64, 8, 1, 16384, 1, 99}; + double _expB[] = 
{6276.0,12831.0,19668.0,26790.0,27012.0,20703.0,14100.0,7200.0,13719.0,28023.0,42918.0,58410.0,58902.0,45105.0,30693.0,15660.0,22389.0,45696.0,69930.0,95100.0,95910.0,73386.0,49899.0,25440.0,32346.0,65970.0,100884.0,137100.0,138276.0,105726.0,71838.0,36600.0,33726.0,68790.0,105204.0,142980.0,144156.0,110226.0,74898.0,38160.0,27555.0,56154.0,85806.0,116520.0,117474.0,89748.0,60933.0,31020.0,19917.0,40557.0,61926.0,84030.0,84714.0,64671.0,43875.0,22320.0,10752.0,21879.0,33384.0,45270.0,45636.0,34815.0,23604.0,12000.0,7551.0,15456.0,23718.0,32340.0,32562.0,24978.0,17025.0,8700.0,16569.0,33873.0,51918.0,70710.0,71202.0,54555.0,37143.0,18960.0,27114.0,55371.0,84780.0,115350.0,116160.0,88911.0,60474.0,30840.0,39246.0,80070.0,122484.0,166500.0,167676.0,128226.0,87138.0,44400.0,40626.0,82890.0,126804.0,172380.0,173556.0,132726.0,90198.0,45960.0,33180.0,67629.0,103356.0,140370.0,141324.0,107973.0,73308.0,37320.0,23967.0,48807.0,74526.0,101130.0,101814.0,77721.0,52725.0,26820.0,12927.0,26304.0,40134.0,54420.0,54786.0,41790.0,28329.0,14400.0,8826.0,18081.0,27768.0,37890.0,38112.0,29253.0,19950.0,10200.0,19419.0,39723.0,60918.0,83010.0,83502.0,64005.0,43593.0,22260.0,31839.0,65046.0,99630.0,135600.0,136410.0,104436.0,71049.0,36240.0,46146.0,94170.0,144084.0,195900.0,197076.0,150726.0,102438.0,52200.0,47526.0,96990.0,148404.0,201780.0,202956.0,155226.0,105498.0,53760.0,38805.0,79104.0,120906.0,164220.0,165174.0,126198.0,85683.0,43620.0,28017.0,57057.0,87126.0,118230.0,118914.0,90771.0,61575.0,31320.0,15102.0,30729.0,46884.0,63570.0,63936.0,48765.0,33054.0,16800.0,17220.0,34863.0,52932.0,71430.0,72228.0,54831.0,36996.0,18720.0,36327.0,73527.0,111606.0,150570.0,152214.0,115521.0,77925.0,39420.0,57381.0,116112.0,176202.0,237660.0,240198.0,182250.0,122907.0,62160.0,80442.0,162738.0,246900.0,332940.0,336420.0,255198.0,172062.0,87000.0,84702.0,171318.0,259860.0,350340.0,353820.0,268338.0,180882.0,91440.0,66867.0,135210.0,205038.0,276360.0,279042.0,211572.0,142581.0,72060.0,46845.0,947
01.0,143574.0,193470.0,195306.0,148047.0,99747.0,50400.0,24576.0,49671.0,75288.0,101430.0,102372.0,77583.0,52260.0,26400.0,22095.0,44688.0,67782.0,91380.0,92178.0,69906.0,47121.0,23820.0,46377.0,93777.0,142206.0,191670.0,193314.0,146571.0,98775.0,49920.0,72906.0,147387.0,223452.0,301110.0,303648.0,230175.0,155082.0,78360.0,101742.0,205638.0,311700.0,419940.0,423420.0,320898.0,216162.0,109200.0,106002.0,214218.0,324660.0,437340.0,440820.0,334038.0,224982.0,113640.0,83292.0,168285.0,254988.0,343410.0,346092.0,262197.0,176556.0,89160.0,58095.0,117351.0,177774.0,239370.0,241206.0,182697.0,122997.0,62100.0,30351.0,61296.0,92838.0,124980.0,125922.0,95358.0,64185.0,32400.0,26970.0,54513.0,82632.0,111330.0,112128.0,84981.0,57246.0,28920.0,56427.0,114027.0,172806.0,232770.0,234414.0,177621.0,119625.0,60420.0,88431.0,178662.0,270702.0,364560.0,367098.0,278100.0,187257.0,94560.0,123042.0,248538.0,376500.0,506940.0,510420.0,386598.0,260262.0,131400.0,127302.0,257118.0,389460.0,524340.0,527820.0,399738.0,269082.0,135840.0,99717.0,201360.0,304938.0,410460.0,413142.0,312822.0,210531.0,106260.0,69345.0,140001.0,211974.0,285270.0,287106.0,217347.0,146247.0,73800.0,36126.0,72921.0,110388.0,148530.0,149472.0,113133.0,76110.0,38400.0,}; + NDArray exp(_expB, _expS); + + auto input = NDArrayFactory::create('c', {2, 3, 4, 4}); + auto weights = NDArrayFactory::create('c', {3, 3, 5, 5}); + auto z = NDArrayFactory::create('c', {2, 3, 8, 8}); + + input.linspace(1); + weights.linspace(1); + weights.permutei({2,3,1,0}); + + sd::ops::deconv2d op; + auto result = op.execute({&input, &weights}, {&z}, {5, 5, 1, 1, 0, 0, 1, 1, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result); + + ASSERT_TRUE(exp.isSameShape(&z)); + ASSERT_TRUE(exp.equalsTo(&z)); +} + +TYPED_TEST(TypedConvolutionTests1, deconv2d_test6) { + + int bS=2, iH=4,iW=4, iC=3,oC=3, kH=5,kW=5, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=8,oW=8; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = 
NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {kH, kW, oC, iC}, {1.f, 76.f, 151.f, 26.f, 101.f, 176.f, 51.f, 126.f, 201.f, 2.f, 77.f, 152.f, 27.f, 102.f, 177.f, 52.f, 127.f, 202.f, 3.f, 78.f, 153.f, 28.f, 103.f, 178.f, 53.f, 128.f, 203.f, + 4.f, 79.f, 154.f, 29.f, 104.f, 179.f, 54.f, 129.f, 204.f, 5.f, 80.f, 155.f, 30.f, 105.f, 180.f, 55.f, 130.f, 205.f, 6.f, 81.f, 156.f, 31.f, 106.f, 181.f, 56.f, 131.f, 206.f, + 7.f, 82.f, 157.f, 32.f, 107.f, 182.f, 57.f, 132.f, 207.f, 8.f, 83.f, 158.f, 33.f, 108.f, 183.f, 58.f, 133.f, 208.f, 9.f, 84.f, 159.f, 34.f, 109.f, 184.f, 59.f, 134.f, 209.f, + 10.f, 85.f, 160.f, 35.f, 110.f, 185.f, 60.f, 135.f, 210.f, 11.f, 86.f, 161.f, 36.f, 111.f, 186.f, 61.f, 136.f, 211.f, 12.f, 87.f, 162.f, 37.f, 112.f, 187.f, 62.f, 137.f, 212.f, + 13.f, 88.f, 163.f, 38.f, 113.f, 188.f, 63.f, 138.f, 213.f, 14.f, 89.f, 164.f, 39.f, 114.f, 189.f, 64.f, 139.f, 214.f, 15.f, 90.f, 165.f, 40.f, 115.f, 190.f, 65.f, 140.f, 215.f, + 16.f, 91.f, 166.f, 41.f, 116.f, 191.f, 66.f, 141.f, 216.f, 17.f, 92.f, 167.f, 42.f, 117.f, 192.f, 67.f, 142.f, 217.f, 18.f, 93.f, 168.f, 43.f, 118.f, 193.f, 68.f, 143.f, 218.f, + 19.f, 94.f, 169.f, 44.f, 119.f, 194.f, 69.f, 144.f, 219.f, 20.f, 95.f, 170.f, 45.f, 120.f, 195.f, 70.f, 145.f, 220.f, 21.f, 96.f, 171.f, 46.f, 121.f, 196.f, 71.f, 146.f, 221.f, + 22.f, 97.f, 172.f, 47.f, 122.f, 197.f, 72.f, 147.f, 222.f, 23.f, 98.f, 173.f, 48.f, 123.f, 198.f, 73.f, 148.f, 223.f, 24.f, 99.f, 174.f, 49.f, 124.f, 199.f, 74.f, 149.f, 224.f, + 25.f, 100.f, 175.f,50.f, 125.f, 200.f,75.f, 150.f, 225.f}); + + auto exp = NDArrayFactory::create('c', {bS, oC, oH, oW}, {6276.0f, 12831.0f, 19668.0f, 26790.0f, 27012.0f, 20703.0f, 14100.0f, 7200.0f, 13719.0f, 28023.0f, 42918.0f, 58410.0f, 58902.0f, 45105.0f, 30693.0f, 15660.0f, 22389.0f, 45696.0f, 69930.0f, 95100.0f, 95910.0f, 73386.0f, 49899.0f, 25440.0f, 32346.0f, 65970.0f, 100884.0f, 137100.0f, 138276.0f, 105726.0f, 71838.0f, 36600.0f, 33726.0f, 68790.0f, 
105204.0f, 142980.0f, 144156.0f, 110226.0f, 74898.0f, 38160.0f, 27555.0f, 56154.0f, 85806.0f, 116520.0f, 117474.0f, 89748.0f, 60933.0f, 31020.0f, 19917.0f, 40557.0f, 61926.0f, 84030.0f, 84714.0f, 64671.0f, 43875.0f, 22320.0f, 10752.0f, 21879.0f, 33384.0f, 45270.0f, 45636.0f, 34815.0f, 23604.0f, 12000.0f, 7551.0f, 15456.0f, 23718.0f, 32340.0f, 32562.0f, 24978.0f, 17025.0f, 8700.0f, 16569.0f, 33873.0f, 51918.0f, 70710.0f, 71202.0f, 54555.0f, 37143.0f, 18960.0f, 27114.0f, 55371.0f, 84780.0f, 115350.0f, 116160.0f, 88911.0f, 60474.0f, 30840.0f, 39246.0f, 80070.0f, 122484.0f, 166500.0f, 167676.0f, 128226.0f, 87138.0f, 44400.0f, 40626.0f, 82890.0f, 126804.0f, 172380.0f, 173556.0f, 132726.0f, 90198.0f, 45960.0f, 33180.0f, 67629.0f, 103356.0f, 140370.0f, 141324.0f, 107973.0f, 73308.0f, 37320.0f, 23967.0f, 48807.0f, 74526.0f, 101130.0f, 101814.0f, 77721.0f, 52725.0f, 26820.0f, 12927.0f, 26304.0f, 40134.0f, 54420.0f, 54786.0f, 41790.0f, 28329.0f, 14400.0f, 8826.0f, 18081.0f, 27768.0f, 37890.0f, 38112.0f, 29253.0f, 19950.0f, 10200.0f, 19419.0f, 39723.0f, 60918.0f, 83010.0f, 83502.0f, 64005.0f, 43593.0f, 22260.0f, 31839.0f, 65046.0f, 99630.0f, 135600.0f, 136410.0f, 104436.0f, 71049.0f, 36240.0f, 46146.0f, 94170.0f, 144084.0f, 195900.0f, 197076.0f, 150726.0f, 102438.0f, 52200.0f, 47526.0f, 96990.0f, 148404.0f, 201780.0f, 202956.0f, 155226.0f, 105498.0f, 53760.0f, 38805.0f, 79104.0f, 120906.0f, 164220.0f, 165174.0f, 126198.0f, 85683.0f, 43620.0f, 28017.0f, 57057.0f, 87126.0f, 118230.0f, 118914.0f, 90771.0f, 61575.0f, 31320.0f, 15102.0f, 30729.0f, 46884.0f, 63570.0f, 63936.0f, 48765.0f, 33054.0f, 16800.0f, 17220.0f, 34863.0f, 52932.0f, 71430.0f, 72228.0f, 54831.0f, 36996.0f, 18720.0f, 36327.0f, 73527.0f, 111606.0f, 150570.0f, 152214.0f, 115521.0f, 77925.0f, 39420.0f, 57381.0f, 116112.0f, 176202.0f, 237660.0f, 240198.0f, 182250.0f, 122907.0f, 62160.0f, 80442.0f, 162738.0f, 246900.0f, 332940.0f, 336420.0f, 255198.0f, 172062.0f, 87000.0f, 84702.0f, 171318.0f, 259860.0f, 350340.0f, 
353820.0f, 268338.0f, 180882.0f, 91440.0f, 66867.0f, 135210.0f, 205038.0f, 276360.0f, 279042.0f, 211572.0f, 142581.0f, 72060.0f, 46845.0f, 94701.0f, 143574.0f, 193470.0f, 195306.0f, 148047.0f, 99747.0f, 50400.0f, 24576.0f, 49671.0f, 75288.0f, 101430.0f, 102372.0f, 77583.0f, 52260.0f, 26400.0f, 22095.0f, 44688.0f, 67782.0f, 91380.0f, 92178.0f, 69906.0f, 47121.0f, 23820.0f, 46377.0f, 93777.0f, 142206.0f, 191670.0f, 193314.0f, 146571.0f, 98775.0f, 49920.0f, 72906.0f, 147387.0f, 223452.0f, 301110.0f, 303648.0f, 230175.0f, 155082.0f, 78360.0f, 101742.0f, 205638.0f, 311700.0f, 419940.0f, 423420.0f, 320898.0f, 216162.0f, 109200.0f, 106002.0f, 214218.0f, 324660.0f, 437340.0f, 440820.0f, 334038.0f, 224982.0f, 113640.0f, 83292.0f, 168285.0f, 254988.0f, 343410.0f, 346092.0f, 262197.0f, 176556.0f, 89160.0f, 58095.0f, 117351.0f, 177774.0f, 239370.0f, 241206.0f, 182697.0f, 122997.0f, 62100.0f, 30351.0f, 61296.0f, 92838.0f, 124980.0f, 125922.0f, 95358.0f, 64185.0f, 32400.0f, 26970.0f, 54513.0f, 82632.0f, 111330.0f, 112128.0f, 84981.0f, 57246.0f, 28920.0f, 56427.0f, 114027.0f, 172806.0f, 232770.0f, 234414.0f, 177621.0f, 119625.0f, 60420.0f, 88431.0f, 178662.0f, 270702.0f, 364560.0f, 367098.0f, 278100.0f, 187257.0f, 94560.0f, 123042.0f, 248538.0f, 376500.0f, 506940.0f, 510420.0f, 386598.0f, 260262.0f, 131400.0f, 127302.0f, 257118.0f, 389460.0f, 524340.0f, 527820.0f, 399738.0f, 269082.0f, 135840.0f, 99717.0f, 201360.0f, 304938.0f, 410460.0f, 413142.0f, 312822.0f, 210531.0f, 106260.0f, 69345.0f, 140001.0f, 211974.0f, 285270.0f, 287106.0f, 217347.0f, 146247.0f, 73800.0f, 36126.0f, 72921.0f, 110388.0f, 148530.0f, 149472.0f, 113133.0f, 76110.0f, 38400.0f}); + + input.linspace(1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + 
+TEST_F(ConvolutionTests1, deconv2d_test7) { + + NDArray exp('c', {3, 2, 4, 4}, {218., 227., 236., 245., 254., 263., 272., 281., 290., 299., 308., 317., 326., 335., 344., 353., 270., 282., 294., 306., 318., 330., 342., 354., 366., 378., 390., 402., 414., 426., 438., 450., 650., 659., 668., 677., 686., 695., 704., 713., 722., 731., 740., 749., 758., 767., 776., 785., 846., 858., 870., 882., 894., 906., 918., 930., 942., 954., 966., 978., 990., 1002., 1014., 1026., 1082., 1091., 1100., 1109., 1118., 1127., 1136., 1145., 1154., 1163., 1172., 1181., 1190., 1199., 1208., 1217., 1422., 1434., 1446., 1458., 1470., 1482., 1494., 1506., 1518., 1530., 1542., 1554., 1566., 1578., 1590., 1602.}); + + auto input = NDArrayFactory::create('c', {3, 3, 4, 4}); + auto weights = NDArrayFactory::create('c',{1, 1, 2, 3}, {1,3,5,2,4,6}); + auto bias = NDArrayFactory::create('c', {2}); + + input.linspace(1); + bias.linspace(1); + + sd::ops::deconv2d op; + + auto result = op.evaluate({&input, &weights, &bias}, {1, 1, 1, 1, 0, 0, 1, 1, 1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test8) { + + int bS=1, iH=7,iW=7, iC=3,oC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=7,oW=7; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, {0.679350, 0.355087, 0.842789, 0.200313, 0.701499, 0.310693, 0.447940, 0.938010, 0.326674, 0.151873, 0.383318, 0.782123, 0.198807, + 0.798564, 0.163263, 0.146968, 0.260897, 0.135058, 0.756209, 0.275454, 0.369088, 0.092826, 0.836492, 0.268413, 0.095062, 0.312795, 0.135918, 0.517544, 0.328703, + 0.061736, 0.396431, 0.248016, 0.548959, 0.115046, 0.814362, 0.721564, 0.404494, 0.299089, 0.403884, 0.988311, 0.022296, 0.927782, 0.318416, 0.068546, 0.284533, + 
0.232720, 0.352142, 0.058909, 0.711221, 0.674457, 0.196946, 0.699497, 0.074322, 0.420425, 0.584263, 0.149574, 0.446406, 0.723072, 0.064481, 0.483078, 0.875996, + 0.569819, 0.445863, 0.527755, 0.016646, 0.753678, 0.140636, 0.754129, 0.161932, 0.775037, 0.332645, 0.117394, 0.017711, 0.608476, 0.525152, 0.917194, 0.849891, + 0.589423, 0.852278, 0.390636, 0.889683, 0.669445, 0.698873, 0.961480, 0.157401, 0.157364, 0.493520, 0.569937, 0.126832, 0.115728, 0.786368, 0.737939, 0.490079, + 0.608414, 0.956500, 0.390098, 0.147305, 0.850645, 0.497650, 0.071866, 0.082150, 0.035314, 0.732041, 0.369934, 0.840666, 0.273894, 0.431796, 0.133231, 0.192975, + 0.246897, 0.386418, 0.511541, 0.199036, 0.141631, 0.697699, 0.253631, 0.782218, 0.930099, 0.335512, 0.558808, 0.664358, 0.018851, 0.637559, 0.290430, 0.434902, + 0.842513, 0.466098, 0.381395, 0.523185, 0.990183, 0.925768, 0.643459, 0.016828, 0.918756, 0.228979, 0.006314, 0.665975, 0.190361, 0.595521, 0.698881, 0.221469, + 0.912434, 0.870822, 0.727369, 0.523972, 0.662884, 0.218841}); + + NDArray weights('c', {kH, kW, oC, iC}, {0.4195024073123932, 0.22738978266716003, 0.10093523561954498, 0.25008103251457214, 0.3183899223804474, 0.5976081490516663}); + NDArray bias('c', {1, oC}, {0.3596062958240509, 0.6866418123245239}); + + NDArray exp('c', {bS, oC, oH, oW}, {0.848190, 0.560603, 0.880509, 0.464103, 0.823376, 0.660138, 0.666382, 0.882257, 0.704650, 0.451427, 0.649734, 0.911822, 0.611581, + 0.847623, 0.568191, 0.439341, 0.710854, 0.473843, 0.927273, 0.605861, 0.724540, 0.530591, 0.804268, 0.478136, 0.602198, 0.639553, 0.669082, 0.855013, 0.678572, + 0.617800, 0.667545, 0.765899, 0.835564, 0.631733, 0.921562, 0.790830, 0.588187, 0.597934, 0.725855, 0.822259, 0.455384, 0.998167, 0.683336, 0.591897, 0.705213, + 0.748148, 0.648922, 0.484723, 0.873482, 1.368675, 0.881096, 1.169214, 0.781504, 1.433406, 1.171439, 1.348675, 1.227033, 1.256600, 0.824772, 1.051633, 1.308692, + 1.148711, 1.334007, 1.014448, 0.813336, 1.408801, 0.916766, 
1.583323, 1.362920, 1.226212, 1.149715, 1.330235, 0.770671, 1.285158, 1.105632, 1.272558, 1.590159, + 1.235054, 1.201363, 1.222816, 1.623673, 1.590317, 1.322463, 1.206481, 1.466262, 0.974741, 0.922343, 1.367100, 1.087943, 1.084952, 1.586691, 1.133576, 1.405098, + 1.471922, 1.484062, 1.212039, 1.144419, 1.266123}); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights, &bias}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test9) { + + int bS=2, oH=4,oW=4, oC=5,iC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int iH=3,iW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, oC, kH, kW}, {100.000000, 75.000000, 50.000000, 25.000000, 95.000000, 70.000000, 45.000000, 20.000000, 90.000000, 65.000000, 40.000000, + 15.000000, 85.000000, 60.000000, 35.000000, 10.000000, 80.000000, 55.000000, 30.000000, 5.000000, 99.500000, 74.500000, 49.500000, 24.500000, 94.500000, 69.500000, + 44.500000, 19.500000, 89.500000, 64.500000, 39.500000, 14.500000, 84.500000, 59.500000, 34.500000, 9.500000, 79.500000, 54.500000, 29.500000, 4.500000, 99.000000, + 74.000000, 49.000000, 24.000000, 94.000000, 69.000000, 44.000000, 19.000000, 89.000000, 64.000000, 39.000000, 14.000000, 84.000000, 59.000000, 34.000000, 9.000000, + 79.000000, 54.000000, 29.000000, 4.000000, 98.500000, 73.500000, 48.500000, 23.500000, 93.500000, 68.500000, 43.500000, 18.500000, 88.500000, 63.500000, 38.500000, + 13.500000, 83.500000, 58.500000, 33.500000, 8.500000, 78.500000, 53.500000, 28.500000, 3.500000, 98.000000, 
73.000000, 48.000000, 23.000000, 93.000000, 68.000000, + 43.000000, 18.000000, 88.000000, 63.000000, 38.000000, 13.000000, 83.000000, 58.000000, 33.000000, 8.000000, 78.000000, 53.000000, 28.000000, 3.000000, 97.500000, 72.500000, 47.500000, 22.500000, 92.500000, 67.500000, 42.500000, 17.500000, 87.500000, 62.500000, 37.500000, 12.500000, 82.500000, 57.500000, 32.500000, 7.500000, 77.500000, 52.500000, 27.500000, 2.500000, 97.000000, 72.000000, 47.000000, 22.000000, 92.000000, 67.000000, 42.000000, 17.000000, 87.000000, 62.000000, 37.000000, 12.000000, 82.000000, 57.000000, 32.000000, 7.000000, 77.000000, 52.000000, 27.000000, 2.000000, 96.500000, 71.500000, 46.500000, 21.500000, 91.500000, 66.500000, 41.500000, 16.500000, 86.500000, 61.500000, 36.500000, 11.500000, 81.500000, 56.500000, 31.500000, 6.500000, 76.500000, 51.500000, 26.500000, 1.500000, 96.000000, 71.000000, 46.000000, 21.000000, 91.000000, 66.000000, 41.000000, 16.000000, 86.000000, 61.000000, 36.000000, 11.000000, 81.000000, 56.000000, 31.000000, 6.000000, 76.000000, 51.000000, 26.000000, 1.000000, 95.500000, 70.500000, 45.500000, 20.500000, 90.500000, 65.500000, 40.500000, 15.500000, 85.500000, 60.500000, 35.500000, 10.500000, 80.500000, 55.500000, 30.500000, 5.500000, 75.500000, 50.500000, 25.500000, 0.500000}, sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oH, oW, oC}, {-30844.250000, -29266.750000, -27689.250000, -26111.750000, -24534.250000, -52823.500000, -49718.500000, -46613.500000, -43508.500000, -40403.500000, -51118.500000, + -48113.500000, -45108.500000, -42103.500000, -39098.500000, -21501.750000, -20024.250000, -18546.750000, -17069.250000, -15591.750000, -42981.000000, -39976.000000, -36971.000000, -33966.000000, -30961.000000, + -69482.000000, -63572.000000, -57662.000000, -51752.000000, -45842.000000, -67072.000000, -61362.000000, -55652.000000, -49942.000000, -44232.000000, -26046.000000, -23241.000000, -20436.000000, -17631.000000, + -14826.000000, -38616.000000, 
-35911.000000, -33206.000000, -30501.000000, -27796.000000, -62252.000000, -56942.000000, -51632.000000, -46322.000000, -41012.000000, -59842.000000, -54732.000000, -49622.000000, + -44512.000000, -39402.000000, -23181.000000, -20676.000000, -18171.000000, -15666.000000, -13161.000000, -12204.250000, -10926.750000, -9649.250000, -8371.750000, -7094.250000, -17543.500000, -15038.500000, + -12533.500000, -10028.500000, -7523.500000, -16838.500000, -14433.499023, -12028.500000, -9623.500000, -7218.500000, -5361.750000, -4184.250000, -3006.750000, -1829.250000, -651.750000, -22046.750000, -20919.250000, + -19791.750000, -18664.250000, -17536.750000, -37478.500000, -35273.500000, -33068.500000, -30863.500000, -28658.500000, -35773.500000, -33668.500000, -31563.500000, -29458.500000, -27353.500000, -14954.250000, + -13926.750000, -12899.250000, -11871.750000, -10844.250000, -29886.000000, -27781.000000, -25676.000000, -23571.000000, -21466.000000, -47792.000000, -43682.000000, -39572.000000, -35462.000000, -31352.000000, + -45382.000000, -41472.000000, -37562.000000, -33652.000000, -29742.000000, -17451.000000, -15546.000000, -13641.000000, -11736.000000, -9831.000000, -25521.000000, -23716.000000, -21911.000000, -20106.000000, -18301.000000, -40562.000000, -37052.000000, -33542.000000, -30032.000000, -26522.000000, -38152.000000, -34842.000000, -31532.000000, -28222.000000, -24912.000000, -14586.000000, -12981.000000, -11376.000000, -9771.000000, -8166.000000, -7906.750000, -7079.250000, -6251.750000, -5424.250000, -4596.750000, -11198.500000, -9593.500000, -7988.500000, -6383.500000, -4778.500000, -10493.500000, -8988.500000, -7483.500000, -5978.500000, -4473.500000, -3314.250000, -2586.750000, -1859.250000, -1131.750000, -404.250000}, sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), 
results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, deconv2d_test10) { + + int bS=2, oH=4,oW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int iH=4,iW=4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kH, kW, oC, iC], [iC, oC, kH, kW], [iC, kH, kW, oC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, kH, kW, oC}, {100., 95., 90., 85., 80., 75., 70., 65., 60., 55., 50., 45., 40., 35., 30., 25., 20., 15., 10., 5., 0., -5., -10., -15., + -20., -25., -30., -35., -40., -45., -50., -55., -60., -65., -70., -75., -80., -85., -90., -95., 99., 94., 89., 84., 79., 74., 69., 64., 59., 54., 49., 44., + 39., 34., 29., 24., 19., 14., 9., 4., -1., -6., -11., -16., -21., -26., -31., -36., -41., -46., -51., -56., -61., -66., -71., -76., -81., -86., -91., -96., + 98., 93., 88., 83., 78., 73., 68., 63., 58., 53., 48., 43., 38., 33., 28., 23., 18., 13., 8., 3., -2., -7., -12., -17., -22., -27., -32., -37., -42., -47., + -52., -57., -62., -67., -72., -77., -82., -87., -92., -97., 97., 92., 87., 82., 77., 72., 67., 62., 57., 52., 47., 42., 37., 32., 27., 22., 17., 12., 7., 2., + -3., -8., -13., -18., -23., -28., -33., -38., -43., -48., -53., -58., -63., -68., -73., -78., -83., -88., -93., -98., 96., 91., 86., 81., 76., 71., 66., 61., + 56., 51., 46., 41., 36., 31., 26., 21., 16., 11., 6., 1., -4., -9., -14., -19., -24., -29., -34., -39., -44., -49., -54., -59., -64., -69., -74., -79., -84., -89., -94., -99.}, sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oC, oH, oW}, {-14128., -21007., -20934., -20861., -13660., -12972., -12926.000977, -12880., -13468., -12788., -12742., -12696.000977, + -13276., -12604., -12558., -12512., -13408., -19569.5, -19501.5, -19433.5, -12230., 
-10117., -10081.000977, -10045., -12058., -9973., -9937., -9901.000977, + -11886., -9829., -9793., -9757., -12688., -18132., -18069., -18006., -10800., -7262., -7236., -7210., -10648., -7157.999512, -7132., -7106., -10496., -7054., + -7027.999512, -7002., -11968., -16694.5, -16636.5, -16578.5, -9370., -4406.999023, -4391., -4375., -9238., -4343., -4326.999023, -4311., -9106., -4279., -4263., + -4246.999023, -11247.999023, -15257., -15204., -15151., -7940., -1551.999023, -1546., -1540., -7828., -1528.000977, -1521.999023, -1516., -7716., -1504., + -1498.000977, -1491.999023, -10527.999023, -13819.5, -13771.5, -13723.5, -6510., 1303.000977, 1299., 1295., -6418., 1286.999023, 1283.000977, 1279., -6326., + 1271., 1266.999023, 1263.000977, -9807.999023, -12382., -12339., -12296., -5080., 4158.000977, 4144., 4130., -5008., 4101.999023, 4088., 4074., -4936., 4046., 4031.999023, 4018., -9088., -10944.5, -10906.5, -10868.5, -3650., 7013., 6989., 6965., -3598., 6917., 6893., 6869., -3546., 6821., 6797., 6773., -8368., -9507., -9474., -9441., -2220., 9868., 9834., 9800., -2187.999512, 9732., 9698., 9664., -2156., 9596., 9562., 9528., -7648., -8069.5, -8041.5, -8013.499512, -790.000488, 12723., 12679., 12635., -777.999512, 12547., 12503., 12459., -766., 12371., 12327., 12283., -10208., -15167., -15094., -15021., -9820., -9292., -9246., -9200., -9628., -9108., -9062., -9016., -9436., -8924., -8878., -8832., -9687.999023, -14129.5, -14061.5, -13993.5, -8790., -7236.999023, -7201., -7164.999512, -8618., -7093., -7057., -7021., -8446., -6949., -6913., -6877., -9168., -13092., -13029., -12966., -7760., -5182., -5156., -5129.999512, -7608., -5078., -5052., -5026., -7456., -4974., -4948., -4922., -8648., -12054.5, -11996.5, -11938.5, -6730., -3127., -3111., -3095., -6598., -3063., -3047., -3031., -6465.999512, -2999., -2983.000488, -2967., -8128., -11017., -10964., -10911., -5700.000488, -1072., -1066., -1060., -5587.999512, -1048.000488, -1042., -1036., -5476., -1023.999512, 
-1018.000488, -1012., -7608., -9979.5, -9931.5, -9883.5, -4670.000488, 983., 979., 975., -4577.999512, 966.999512, 963., 959., -4486., 951.000488, 946.999512, 943., -7088., -8942., -8899., -8856., -3640.000488, 3038., 3024., 3010., -3567.999512, 2981.999512, 2968., 2954., -3496., 2926.000488, 2911.999512, 2898., -6568., -7904.5, -7866.5, -7828.499512, -2610.000488, 5093., 5069., 5045., -2557.999512, 4996.999512, 4973., 4949., -2506., 4901.000488, 4877., 4853., -6048., -6867., -6834., -6800.999512, -1580., 7148., 7114., 7080., -1547.999512, 7012., 6978., 6944., -1516., 6876.000488, 6842., 6808., -5528., -5829.5, -5801.5, -5773.499512, -550., 9203., 9159., 9115., -537.999512, 9027., 8983., 8939., -526., 8851., 8807., 8763.}, sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv2d op; + auto results = op.evaluate({&input, &weights}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) { + + int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 
125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f, + 52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d_tf op; + auto results = op.evaluate({&outShape, &weights, &input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_bp_7) { + int bS=2, iH=12,iW=12, iC=3,oC=3, kH=3,kW=3, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; + int oH=6,oW=6; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {1,2,3}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray gradW('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); + NDArray gradB('c', {oC}, 
sd::DataType::FLOAT32); + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::conv2d_bp op; + auto status = op.execute({&input, &weights, &bias, &gradO}, {&gradI, &gradW, &gradB}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); + ASSERT_EQ(Status::OK(), status); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_ff_119_1) { + auto i = NDArrayFactory::create('c', {2, 3, 13, 13}); + auto w = NDArrayFactory::create('c', {3, 3, 3, 3}); + auto b = NDArrayFactory::create('c', {3}); + auto o = NDArrayFactory::create('c', {2, 3, 6, 6}); + + sd::ops::conv2d op_ff; + auto status = op_ff.execute({&i, &w, &b}, {&o}, {3,3, 2,2, 0,0, 1,1, 0, 0, 1}); + ASSERT_EQ(Status::OK(), status); + + auto gi = i.ulike(); + auto gw = w.ulike(); + + sd::ops::conv2d_bp op_bp; + status = op_bp.execute({&i, &w, &b, &o}, {&gi, &gw}, {3,3, 2,2, 0,0, 1,1, 0, 0, 1}); + ASSERT_EQ(Status::OK(), status); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv2d_ff_119_2) { + auto i = NDArrayFactory::create('c', {2, 3, 17, 17}); + auto w = NDArrayFactory::create('c', {3, 3, 3, 3}); + auto b = NDArrayFactory::create('c', {3}); + auto o = NDArrayFactory::create('c', {2, 3, 8, 8}); + + sd::ops::conv2d op_ff; + auto status = op_ff.execute({&i, &w, &b}, {&o}, {3,3, 2,2, 0,0, 1,1, 0, 0, 1}); + ASSERT_EQ(Status::OK(), status); + + auto gi = i.ulike(); + auto gw = w.ulike(); + + sd::ops::conv2d_bp op_bp; + status = op_bp.execute({&i, &w, &b, &o}, {&gi, &gw}, {3,3, 2,2, 0,0, 1,1, 0, 0, 1}); + ASSERT_EQ(Status::OK(), status); +} + +#endif //LIBND4J_CONVOLUTIONTESTS1_H + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConvolutionTests2.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConvolutionTests2.cpp new file mode 100644 index 000000000..cd6eb4d70 --- /dev/null +++ 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -0,0 +1,2859 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// @author Yurii Shyrma (iuriish@yahoo.com), created 02.04.2019 +// + +#ifndef LIBND4J_CONVOLUTIONTESTS2_H +#define LIBND4J_CONVOLUTIONTESTS2_H + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class ConvolutionTests2 : public testing::Test { +public: + + const int bS = 2; // batch size + const int iD = 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int oD = 3; // output depth (= N for dense layer) + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // 
dilation width + const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height + const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width + +}; + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, im2col_1) { + + int bS=2, iH=4,iW=3, iC=4, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH = (iH - (kH + (kH-1)*(dH-1)) + 2*pH)/sH + 1; // VALID + int oW = (iW - (kW + (kW-1)*(dW-1)) + 2*pW)/sW + 1; // VALID + + int paddingMode = 0; // 1-SAME, 0-VALID; + + NDArray image('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); + NDArray expected('c', {bS, iC, kH, kW, oH, oW}, {1, 2, 4, 5, 2, 3, 5, 6, 4, 5, 7, 8, 5, 6, 8, 9, 7, 8, 10, 11, 8, 9, 11, 12, 13, 14, 16, 17, 14, + 15, 17, 18, 16, 17, 19, 20, 17, 18, 20, 21, 19, 20, 22, 23, 20, 21, 23, 24, 25, 26, 28, 29, 26, 27, 29, 30, + 28, 29, 31, 32, 29, 30, 32, 33, 31, 32, 34, 35, 32, 33, 35, 36, 37, 38, 40, 41, 38, 39, 41, 42, 40, 41, 43, + 44, 41, 42, 44, 45, 43, 44, 46, 47, 44, 45, 47, 48, 49, 50, 52, 53, 50, 51, 53, 54, 52, 53, 55, 56, 53, 54, + 56, 57, 55, 56, 58, 59, 56, 57, 59, 60, 61, 62, 64, 65, 62, 63, 65, 66, 64, 65, 67, 68, 65, 66, 68, 69, 67, + 68, 70, 71, 68, 69, 71, 72, 73, 74, 76, 77, 74, 75, 77, 78, 76, 77, 79, 80, 77, 78, 80, 81, 79, 80, 82, 83, + 80, 81, 83, 84, 85, 86, 88, 89, 86, 87, 89, 90, 88, 89, 91, 92, 89, 90, 92, 93, 91, 92, 94, 95, 92, 93, 95, 96}); + + image.linspace(1, 1); + + sd::ops::im2col op; + auto results = op.evaluate({&image}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); + auto column = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(column)); + ASSERT_TRUE(expected.equalsTo(column)); + +} + +template +class TypedConvolutionTests2 : public testing::Test { +public: + +}; + +typedef ::testing::Types TestingTypes; +TYPED_TEST_CASE(TypedConvolutionTests2, TestingTypes); + +////////////////////////////////////////////////////////////////////// 
+TYPED_TEST(TypedConvolutionTests2, deconv2d_tf_test2) { + + int bS=2, iH=4,iW=4, iC=5,oC=10, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, oC}); + auto outShape = NDArrayFactory::create('c', {4}, {static_cast(bS), static_cast(iH), static_cast(iW), static_cast(iC)}); + auto exp = NDArrayFactory::create('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, + 55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv2d_tf op; + auto results = op.evaluate({&outShape, &weights, &input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_1) { + auto input0 = NDArrayFactory::create('c', {4}, {12.f, 5.f, 5.f, 32.f}); + auto input1 = NDArrayFactory::create('c', {2, 2, 32, 16}); + auto input2 = NDArrayFactory::create('c', {12, 4, 4, 16}); + auto exp = NDArrayFactory::create('c', {12, 5, 5, 32}); + + sd::ops::deconv2d_tf op; + auto result = op.evaluate({&input0, &input1, &input2}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(exp, *result.at(0)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, Test_DeConv2D_TF_2) { + auto input0 = NDArrayFactory::create('c', {4}, {3.f, 8.f, 8.f, 16.f}); + + auto input1 = NDArrayFactory::create('c', {7, 7, 16, 5}, {1.05293429f, -0.89349967f, 0.31027254f, 1.22991478f, -0.62926656f, 0.56918693f, +-1.60992694f, 1.10167944f, -0.80843484f, 0.07521993f, -1.15994942f, 0.76016301f, -0.40056285f, -1.16872537f, -0.91384381f, -0.36700436f, 1.82389200f, -1.18200207f, 0.51612782f, -0.92479187f, -0.09307563f, -0.55122334f, 1.23532486f, -1.11124146f, -0.05812126f, 0.68159896f, 0.69125599f, -0.77127314f, -0.10874277f, 0.86469102f, +-1.31614351f, 0.33354419f, -1.71750402f, 0.17197680f, -1.03965557f, 1.10570908f, -1.19115615f, 1.05115080f, 0.18277600f, 1.08820546f, -0.72191417f, -0.10999311f, 1.56521320f, -0.35433730f, -1.11799145f, 0.34499285f, 0.64998639f, -1.64371550f, 0.92592359f, -0.47659501f, 0.49101439f, -0.15613313f, 1.47486567f, 0.43576995f, +2.19538260f, -0.83567709f, -1.21846950f, 0.80400819f, 1.14637423f, -1.01503456f, -0.61992753f, -0.47378838f, 0.86503726f, 0.27147385f, 0.37073180f, -0.19951358f, 0.79167330f, -0.33982825f, 0.18631981f, -1.54715073f, 0.39967480f, 0.95067030f, 1.12508667f, -0.86676019f, -1.10341156f, 2.33141375f, 1.10972047f, 0.71407092f, 
+1.70640314f, 1.80666339f, 0.59465605f, -0.39653218f, -2.61163163f, -1.15013492f, -1.19908321f, 0.41783467f, -0.22730024f, 0.31425011f, -0.58562893f, -0.10131568f, -0.85047537f, -2.59974790f, 1.22072542f, -2.08812046f, -0.19363593f, -1.27664304f, -0.02703438f, 1.08477545f, -0.65506506f, 0.46040919f, -0.13715318f, +-0.74945593f, -0.69006950f, -1.29617655f, -0.15865716f, 1.38956285f, 0.90216327f, -1.31185400f, -0.15067385f, -0.63093358f, -0.05895613f, 0.26545224f, 0.29332840f, 0.42852548f, 0.72409540f, 0.12879130f, 1.43038857f, 0.68647617f, 2.19654775f, 0.51878077f, -0.03769343f, 0.52877223f, -0.21733910f, 1.13710785f, -0.59003806f, +1.54624867f, -0.64997369f, -1.03239334f, 0.19708300f, 0.68658423f, 0.71048903f, -1.55250466f, -1.38636279f, 0.32385820f, 0.81226677f, 0.19209047f, -0.23002781f, -0.63631231f, 1.02101684f, 0.65428704f, -0.17206922f, 1.09488952f, 1.03022420f, -0.95567745f, -0.07595373f, -1.48606372f, 2.57174873f, -1.75366247f, 1.12913883f, +0.97053039f, -0.28552356f, 0.56511772f, -0.79568213f, 0.07561764f, -1.02085686f, 1.05770981f, -1.25715709f, 0.42046708f, -2.57390857f, 0.96947151f, 1.05215812f, 0.65624017f, -1.29019403f, 0.64157075f, -0.40509227f, -0.65354455f, 0.42348680f, -1.34107757f, 0.05931387f, -0.54337227f, 0.95460182f, 1.59319806f, -0.44433126f, +-0.33717924f, 0.79566282f, 0.50112695f, -0.22244534f, 1.76904583f, -0.89817202f, 1.82985342f, 0.17671813f, 0.80720717f, 1.32469308f, 0.39417782f, -0.23720963f, 0.96796370f, -1.02348757f, -0.86615551f, -1.58120525f, -0.37634999f, 0.00905940f, 0.01880967f, 1.75771821f, -0.64372772f, 0.36687651f, 0.15854552f, -0.67599791f, +0.53726906f, -1.20158446f, -1.78549063f, 0.96476388f, -0.66158366f, -0.41681561f, -0.97541636f, 2.35928202f, 0.32130197f, 1.06886065f, 1.38736427f, -0.73718959f, 0.11215294f, 2.12865782f, -0.37927702f, 0.55621815f, -1.10108411f, -0.02032263f, 0.29595461f, 1.58737493f, 1.24001300f, -0.66748160f, 0.80729002f, -0.10575818f, +-1.03175950f, 1.80755460f, 0.10825710f, 2.20666361f, 
1.33633149f, 1.39290452f, 0.45211342f, -0.07837920f, 2.08304930f, -0.28387162f, -0.70775616f, 0.43626297f, 0.53556961f, 0.06201901f, -0.59255266f, -0.11854446f, 2.10024118f, 0.37638292f, -0.56178707f, -0.25220188f, -1.23731256f, -1.30002999f, 0.34283713f, 0.30502397f, +-1.09233856f, 1.12430644f, 0.52273953f, -0.68507338f, -0.69913578f, 0.88440478f, -0.76959240f, 1.07093310f, -0.34802195f, 0.35683727f, -0.76079178f, -1.92807376f, 0.84499562f, 1.39131641f, 0.44825050f, 0.34567752f, 0.44607711f, -1.00986362f, -0.50038189f, -0.09060892f, -2.55645394f, 0.56416476f, -0.83058155f, -0.65931624f, +-0.73649710f, 0.59814465f, -0.86736494f, -0.32200798f, -1.28087902f, -0.76818323f, 0.86848933f, -0.98678392f, -1.30813944f, -0.20255326f, 0.26557815f, -0.31090519f, -1.46331608f, -0.62782109f, 0.59034890f, 1.63147473f, -0.17727259f, -0.37636510f, 1.27368402f, 0.19096918f, -0.29936951f, -1.99038267f, 0.54831523f, 0.48849005f, -2.55680346f, -0.63126534f, 1.21715927f, 1.22841084f, -0.67416084f, 0.02927168f, -0.36693662f, 0.63204330f, 0.13721083f, 0.28742912f, 0.19470036f, 0.74873924f, -1.47602463f, 0.86264688f, -0.23730527f, -0.99978864f, -1.17048764f, -0.34996086f, 1.43019187f, 0.26224539f, 0.60689932f, -0.75002515f, -0.79823422f, -1.37300086f, -0.19951135f, -0.12150808f, -0.75272322f, 0.23755015f, 0.31270382f, 1.66539109f, -1.04104745f, 0.79540199f, -0.54042423f, -0.54150617f, 0.43871084f, 0.24163951f, -0.24517761f, -0.66178995f, -1.13064528f, -0.84426326f, 0.56437236f, 0.09088907f, -0.82823074f, 0.81753862f, -1.74096012f, -1.80599844f, -0.60943592f, 1.36094582f, -1.47762752f, 0.15931177f, 1.05569172f, 0.36751524f, 0.06497604f, 0.13536447f, -1.57156146f, 0.22783801f, -0.96910107f, -1.24294984f, -1.47147155f, -1.04790676f, 0.64629447f, -0.32266054f, -0.55675793f, -0.95612079f, -0.23005411f, -0.75229394f, 0.03050950f, -1.72484553f, -2.06055546f, 0.19892083f, -0.13597751f, 0.65180075f, 0.27096850f, 0.08977254f, 0.57564765f, -0.43227410f, 0.09541437f, -0.00358280f, 0.65680492f, 
0.04006556f, 0.57160908f, 0.43821687f, 1.96118212f, 0.42602235f, -0.36731303f, 0.67200917f, -0.56667900f, 0.44014785f, 0.06970236f, -1.34415269f, -1.13301528f, -0.08848868f, 0.35615012f, -0.06426942f, -0.81406075f, 0.94097465f, -0.54560357f, -0.65877116f, -1.29646838f, -1.13109028f, -1.64186084f, -2.12723470f, 1.86027610f, 1.22621441f, 0.26098135f, -0.05608099f, 0.21143445f, -0.87244326f, 0.79408187f, 1.24279130f, 0.14458629f, 0.25532281f, -1.24023473f, 2.42278886f, 0.00405578f, -1.00119174f, 1.19856644f, -1.37395728f, -0.16656208f, 0.46858498f, -0.00678801f, -0.34960639f, 0.16614936f, 2.41560221f, -0.53880709f, 0.91618651f, -1.77009308f, 0.32911557f, 0.30216452f, 0.02881077f, 0.77705866f, 0.27061903f, -0.07440855f, -1.14010465f, 1.25383139f, -1.58615100f, 1.04185510f, 0.15140508f, -0.88059032f, -0.33872122f, -0.42526904f, 2.17365575f, 0.29308075f, -2.24234557f, -1.03164542f, -0.09263755f, 0.08050421f, -0.74946511f, -0.64589006f, -1.13416314f, -0.64989561f, 0.16502371f, -0.33831969f, 0.22832428f, -0.08389475f, -0.28009200f, 1.34536922f, -0.19075738f, 0.36238208f, 0.83690089f, 0.26144615f, 0.04457319f, -2.55585861f, -0.01807522f, 1.68334866f, -0.05795629f, -0.21315987f, -1.84039557f, 0.06512877f, -1.77318645f, -0.27637982f, 0.20439345f, 0.67558700f, -0.77179354f, -0.17902173f, 0.70381826f, -0.40395790f, -0.96492916f, 0.84138173f, 2.43879008f, -0.32297835f, -1.74370265f, -0.10330839f, -1.07465363f, 1.85030377f, -0.59153467f, 0.99667048f, -0.56753993f, 0.57383025f, -1.90630126f, 1.24299097f, 0.22797665f, 0.30468231f, -0.07360230f, 1.64654350f, 0.57195550f, 0.03227921f, 1.11005175f, 0.00088721f, 1.19266295f, 0.61323351f, 0.13754399f, 0.59900171f, -0.75831634f, 1.11500823f, 0.99747783f, -1.36923385f, 1.26563418f, 0.01253266f, 0.35483193f, 1.95143735f, -2.02703261f, -1.38265920f, -0.02404256f, 2.02788448f, -0.75144875f, -0.58445263f, 0.26129767f, 0.60691077f, -1.84661067f, 0.65872228f, -0.58298993f, 0.33067298f, -0.09431327f, 0.43333948f, -1.52616286f, -0.25961858f, 
-1.65459549f, -0.72950101f, -0.89906919f, -0.80081612f, -1.32189929f, -1.36574399f, -0.35809481f, 0.36385000f, 0.31480747f, -0.35797358f, -1.04066050f, 0.07971872f, -0.21176252f, -0.76559299f, -0.10352154f, 0.29248312f, -1.75030553f, 0.68219930f, 0.56189102f, -1.11212170f, 0.06501702f, -0.07131009f, 1.23410738f, 0.29311740f, -1.02052307f, 1.40220940f, -1.00995779f, 0.57955760f, 0.22640309f, 0.74853230f, -0.02586563f, -0.33427954f, 1.70311153f, -0.53405988f, 0.90975094f, -0.46450076f, 0.19904344f, 0.28559047f, 0.23167793f, -0.69065529f, -0.17176504f, -0.29301846f, -0.85477978f, -0.00267053f, -0.28529504f, -0.64201307f, 1.03479636f, 1.03805065f, 0.83270210f, -0.09405448f, 2.50615931f, 0.62019676f, 0.31354564f, -1.51599669f, 0.42848015f, 0.66263914f, 0.74651009f, -1.13042867f, -0.58933645f, -0.35146511f, 0.06223279f, 0.28065836f, 0.66506970f, 0.16942430f, -0.23316263f, -0.87481076f, 1.21992230f, 1.48536301f, -0.79667616f, -0.75519305f, 1.40999961f, -0.42802793f, -0.20252463f, 0.30573779f, -0.23319976f, 1.77525878f, -1.80704832f, 2.71519923f, -0.67500192f, 0.12268137f, -0.13014549f, -0.07479453f, -1.51065743f, 1.04198146f, 0.96205556f, -2.00525570f, -0.37911776f, 0.89329720f, -0.39495832f, -0.03683375f, -0.90928614f, -1.56263304f, 0.45038295f, -2.62184358f, -0.45686841f, -0.52536523f, 1.05351484f, 0.89982438f, -0.63724512f, 3.21004057f, -0.08608918f, 1.55209303f, 0.62688643f, -0.59702635f, 1.85774517f, 0.38172096f, -1.25640929f, -2.59278178f, 0.85050315f, -1.10080361f, -1.26422560f, -1.80045366f, -0.34494889f, 0.68448657f, 1.25671864f, -1.26594126f, 0.32244179f, -0.51956522f, -0.56212711f, -0.95574015f, 0.71973872f, 0.46736258f, -0.11772985f, -1.52736545f, 0.19571695f, 0.73147154f, 0.87724912f, -0.26265728f, -2.60267401f, 0.19263546f, 0.18320183f, 0.11485019f, -0.82999659f, 0.13582672f, -0.08040185f, 0.28152901f, -0.51421624f, -2.32467175f, 0.19923948f, 0.64616692f, 0.29718629f, 0.32785949f, -0.62266952f, -0.98174316f, 1.23276305f, 0.58563638f, 1.28528512f, 
-2.13718534f, 0.28842899f, 0.12676710f, -1.72105229f, 0.15053287f, 2.19496536f, 1.28683448f, -0.96318281f, 0.17043279f, -0.05245409f, -0.38710704f, -0.30441490f, -0.08249986f, 0.28423953f, 0.72963721f, -1.49658203f, 0.99077344f, -0.78913772f, -1.12661564f, -1.26294816f, 0.16517465f, 0.10124251f, -0.77198768f, -0.16342169f, 0.08615876f, 0.49711797f, -0.66083062f, 0.76648003f, 1.04756033f, 1.46122825f, -0.42798752f, -2.29203916f, 0.30444992f, 0.58697921f, 1.22166932f, 0.09022947f, -0.03920181f, 0.10444995f, 0.10361757f, 1.18224072f, -0.76641631f, 0.90802073f, 1.41639423f, 1.55682337f, 1.28101575f, -0.35396016f, 1.11443567f, 1.18218529f, -0.06048089f, 0.85024464f, -1.01789165f, -0.69154263f, 0.06663221f, 0.68429029f, 0.12560424f, 0.37915874f, -0.66829866f, -0.64524972f, -0.05568011f, 0.12230454f, -0.35041061f, 0.62027830f, -0.16739209f, -0.72145337f, 0.46263054f, -1.67837834f, 0.69413221f, -0.57243419f, 0.37638462f, -0.21446526f, -0.89821470f, 0.60078722f, -1.06706369f, -1.26132309f, 0.35714921f, 2.39221811f, -0.09376130f, 0.30760849f, 0.59180892f, 0.55815399f, -0.32628775f, 1.28890121f, -2.53237987f, -0.98241091f, 1.10520673f, -1.74751687f, -0.90837651f, -0.25220659f, -0.56625104f, -0.30691949f, 0.16058689f, 0.44309673f, -1.09874964f, -0.76747823f, -0.33679363f, -0.02535496f, 0.00990100f, 1.35318136f, -0.70140815f, 0.50937581f, 0.55386209f, -1.21721983f, 0.71376961f, -0.18079315f, -0.11077732f, 0.09292522f, -0.57235324f, 0.62748206f, 0.42587611f, 0.64860481f, -1.10635614f, 1.66414368f, 0.47505483f, 1.48602211f, -0.59611166f, -0.41932896f, -0.96542233f, -0.41756630f, -1.02963889f, -0.70070386f, 1.65803933f, 0.20138647f, 0.05895034f, -1.46152759f, -0.37278318f, 1.05535650f, 0.34437978f, -1.13257408f, 0.17635690f, 0.09386671f, 0.37079874f, 1.47695887f, -1.58420062f, -0.26100200f, 0.44847637f, 0.88847303f, -0.13877590f, -0.64620668f, -0.38019657f, 1.01608157f, 0.13357787f, 0.05137976f, 0.93498152f, -0.62226880f, 0.80461699f, -0.71682596f, -0.88756353f, 0.40933055f, 
-1.52167451f, 0.79756850f, -0.17307425f, 0.62368619f, -0.22466940f, -1.72802913f, 0.59047443f, -0.58020931f, 0.09096476f, -0.07317388f, 0.44522321f, -0.64880705f, 0.15684015f, 0.08708375f, -0.41556796f, 1.11579072f, -0.81733495f, 0.11643656f, -0.73995101f, 0.93685871f, 1.57971406f, 0.67606360f, 0.70509088f, -0.25283816f, -0.00010609f, -0.61884147f, -0.86409342f, 0.95383751f, -0.05895388f, -1.45261180f, 0.45166013f, -1.01434863f, 0.18496066f, 1.06517637f, 1.81127059f, 0.89470667f, -0.13232610f, 0.46958798f, 0.13884509f, 0.57117194f, 0.29575035f, -0.97884250f, 0.83291447f, -0.59255791f, -0.04354135f, -0.19431923f, 0.30071029f, -0.95421529f, 0.76359886f, -0.47799742f, 0.68254346f, 1.19368529f, -0.48935115f, 0.30357337f, -0.50225669f, -0.23370270f, 1.96702433f, 1.46558523f, 2.68482018f, 0.41622332f, 0.73697484f, 1.43430734f, 0.15387188f, 0.20875402f, -2.49335337f, -1.39674246f, -0.22125854f, -0.00424605f, 0.91416460f, 0.33384630f, 0.44703746f, 0.25610185f, 0.38966551f, -0.01784045f, 1.66148460f, 0.36005461f, 0.95716912f, -0.18246566f, -0.15480693f, 0.38775176f, -0.56969136f, -0.29644895f, -1.04565966f, -1.00455630f, 0.30897698f, -1.46885884f, 0.03657720f, -0.49302089f, 1.34134722f, 0.01673754f, 1.22725964f, 0.55256772f, 0.63803208f, -0.29041430f, 1.11455286f, 0.76329172f, 0.27073982f, 0.77173829f, -1.79884446f, -0.11889492f, -1.92040312f, -0.46382675f, 0.20078070f, -0.98889589f, 1.46711135f, -1.68280172f, -0.52852470f, 0.66245162f, 0.29575166f, 1.34826505f, -0.22362417f, -0.14345661f, -2.34815073f, 1.26572001f, 0.66505629f, 1.01141500f, 1.08030057f, 0.17036134f, 0.00168786f, -0.37282917f, 0.69206375f, 1.07367527f, -0.49708191f, 1.49504781f, 0.58224988f, 0.96593714f, -1.07661915f, 0.25202179f, 0.25531644f, 0.42357162f, -0.31236249f, 0.48383278f, -0.06361829f, 0.24131298f, -0.95695931f, -0.12589653f, 0.36134180f, 3.20266032f, -0.40879184f, -0.66985190f, 1.51674330f, 0.34072638f, 1.15076303f, -0.40199137f, 0.46223637f, -0.48608047f, 0.99119538f, -0.22506073f, 0.30968750f, 
0.64210880f, 0.54640514f, 0.18607031f, 1.26293361f, -0.77960914f, 0.79572529f, 1.01936150f, 2.27160740f, -1.48034489f, 0.74466604f, 0.14863680f, 0.31102443f, -1.15673816f, -0.38609681f, -2.65026069f, -0.45524642f, -0.74022961f, 2.74991131f, 0.00103815f, -3.03303242f, -0.41556966f, -0.87103498f, 0.78306234f, -0.88195556f, -0.77297026f, 1.21203196f, -1.09754920f, -0.03556008f, -0.31546223f, 0.72954375f, 0.25251788f, 0.11378583f, 0.50921023f, 0.30301905f, -1.60631680f, 0.27152416f, 1.17342317f, -0.70891970f, -0.08392961f, 0.92137378f, -0.10568139f, -0.31653777f, -0.28878728f, 1.22166574f, 1.12693942f, -0.21325994f, 0.94010323f, 1.21796405f, -0.68866694f, 2.30724216f, 0.28141466f, 0.83481526f, -0.04885862f, 0.01675143f, 1.04355800f, -0.81050140f, 1.51300573f, 0.53429186f, -0.56439877f, 0.38572624f, -0.05620475f, 0.67644542f, 0.72528905f, 0.05937041f, -1.06315899f, -0.51393986f, 0.46937627f, -0.34699562f, -0.64765716f, -1.45512629f, 0.47739139f, -0.88228017f, -2.00791359f, 1.29929042f, 0.05482405f, -0.66725296f, -0.54735124f, 0.09972951f, 0.76675093f, 0.98748523f, 0.08900899f, -0.78854066f, 1.47970486f, -0.61667502f, 0.45625573f, -0.21766303f, -0.46250847f, -0.07130960f, 0.64414692f, 0.12784545f, 0.26393634f, 1.07720757f, -1.23938286f, 0.62483376f, -0.55001754f, -0.05358591f, 0.07322436f, 1.12003291f, -1.00830650f, -0.20486419f, 0.76664752f, 0.28850746f, -0.04464776f, -0.40146068f, 0.73262817f, -1.12827921f, -0.19989438f, -1.15999687f, 1.37973154f, 0.78881019f, -0.34762639f, 1.22088552f, -1.64088547f, 0.63218033f, 0.45736769f, 0.05502866f, 2.22683382f, -1.78935897f, -1.49635041f, 0.83450896f, 1.67770112f, 1.33909333f, 1.51158953f, 0.28595078f, -0.08593627f, 0.45812801f, -0.15193029f, 1.14770603f, -0.88920450f, -1.96352005f, -1.49894583f, 0.49629962f, 1.59872091f, 0.00903497f, 2.15563583f, 2.25149560f, -2.01200557f, 2.56229877f, -1.38850498f, 0.73552012f, -0.39378855f, 0.52616280f, -0.03685786f, 0.87403935f, 0.12163408f, 0.74297994f, -0.30697080f, 0.38139752f, 
0.49113834f, -0.95485127f, -0.99908817f, 0.71716321f, 0.04000283f, -2.09645271f, 1.38789880f, 1.37198520f, 0.82493287f, 0.17114936f, 0.53696346f, -0.19516060f, -0.50377476f, -0.91730285f, -0.70113552f, -0.02406530f, 0.84943396f, -0.17428185f, -1.09140801f, -0.68156958f, 1.70756388f, -1.00399911f, 0.03023832f, -0.39023280f, -1.89737976f, 1.14469039f, -0.58337289f, -0.60037899f, -1.17490256f, -1.56342828f, 0.48714057f, 0.62266618f, -0.15967095f, 1.32789338f, -1.25700688f, -0.55633998f, -0.83128709f, -0.49346271f, 1.59561753f, -0.24675299f, 0.38012561f, 0.91796309f, -0.38522810f, -0.65509188f, 0.94100451f, -0.57324487f, 2.19070768f, 1.24058700f, -0.75978851f, -0.40460554f, 0.79189235f, 0.70192885f, 1.93569362f, -0.03070199f, 0.77010989f, 0.58794290f, 0.51087004f, 0.22892070f, 0.35007235f, 1.56023848f, -0.67453802f, -0.18485607f, 0.64349502f, -0.31489357f, -1.95834625f, 0.06560058f, 2.30394220f, 1.18194163f, -0.88034087f, -1.05000436f, -1.05471325f, -0.98481798f, 0.49904808f, 0.16438948f, -1.10297823f, -1.39736509f, 0.01306054f, -1.85160267f, -0.87292641f, -0.15418227f, 0.43412164f, 1.16518164f, 0.06273691f, 0.24659210f, -0.08267246f, 1.28885782f, 0.73575675f, -0.01019809f, -0.08753663f, -0.61827368f, -0.40863234f, 2.12599611f, -0.53620332f, 0.53789747f, -0.66386080f, -1.70461988f, 0.86608189f, -1.11151052f, 0.14120635f, 1.18858743f, -0.31760478f, -0.73533046f, 0.20978074f, -0.84074509f, 0.16523147f, -1.03362834f, 0.59721231f, 0.21318658f, 0.23671274f, 1.75115061f, 0.25363782f, -1.32541454f, 1.13056135f, 0.24652456f, 0.60381413f, 0.21478581f, 0.75044096f, -0.63125616f, -1.69889998f, -0.02116571f, 1.46165359f, 1.03068244f, 0.63693464f, 0.67795700f, 1.20033514f, -1.39205134f, -0.61743122f, 0.56549704f, 0.65182322f, -0.74250507f, -1.61939359f, 1.14054918f, -0.45725963f, 1.74519682f, -0.66251940f, -0.94811529f, -1.60865819f, -0.59968346f, 0.86309159f, -1.91936195f, -1.02646923f, -1.50352538f, 0.58292735f, 0.05320299f, 1.53582895f, 0.01069612f, 0.15226212f, -0.71840125f, 
-1.36896348f, 2.14600968f, 0.96626586f, -0.52014917f, 0.41001406f, 0.59478027f, 0.15282436f, 0.27790198f, 0.76614654f, -0.38971323f, -0.01839927f, -1.57882118f, 0.61391610f, -0.62133092f, -0.03968323f, -0.88467252f, -1.24041140f, 2.07306671f, -0.41776338f, 0.14537935f, -0.91069067f, 1.67362070f, 4.72630215f, -0.07395106f, 0.46280116f, -0.40843824f, 0.70683080f, -0.27510864f, -0.63465804f, -0.83630908f, -0.44419941f, 0.60405648f, -0.65039170f, -1.02413189f, 1.05983019f, 1.73366308f, 0.73343736f, -0.00895882f, -1.00826013f, 0.17323074f, 0.73995626f, 0.24128854f, 0.94510227f, 0.25557515f, 0.02244723f, -0.95197725f, -0.16297856f, -0.38497585f, 1.17993331f, 1.20282137f, -1.31491220f, 0.44229278f, -0.24349044f, -0.01230415f, 1.37944865f, 0.48554277f, -0.54510897f, -0.10793537f, 0.41121426f, -0.12889031f, 0.26434359f, 1.27966082f, 0.64518744f, -0.15577169f, -0.99864733f, -0.61746484f, 2.01614976f, 1.56254935f, 1.86473298f, -0.54662132f, -0.22047071f, -0.06118120f, 0.84799510f, 0.17009684f, -1.30523121f, 0.64000309f, 0.36299205f, -0.59620583f, 1.36372304f, -0.05389515f, -0.93849313f, 0.98043185f, -0.39373067f, -0.84898937f, 1.32077873f, 1.05988657f, -1.35339200f, 0.23259017f, 0.63816410f, -0.80297333f, 0.60017115f, 1.25715804f, 1.18894124f, -0.62473553f, 1.05611980f, 0.02335166f, 1.07509828f, 0.25873449f, -1.68341100f, 0.54547334f, 0.79288185f, -0.93678916f, 0.19202201f, -1.48575914f, 1.08649087f, 0.50851744f, -0.45758674f, -0.39734635f, 0.35637981f, -1.63079453f, -0.75910008f, 0.92640859f, -0.55599529f, -0.40276715f, 0.31307653f, 0.39907026f, -1.18830419f, 0.71051043f, 0.14157933f, -0.39581308f, -1.64361024f, -0.06161860f, -0.25312796f, 1.10018682f, 0.56500763f, 0.80385065f, 0.35395023f, 0.81813669f, 0.27644628f, 0.65563256f, 1.73197234f, 0.68178749f, 0.76769936f, 0.44597456f, 0.67761195f, 0.67635447f, -0.32315412f, 0.19330767f, -0.25557944f, 1.91693723f, 0.38335562f, 0.07107610f, -0.57384586f, 0.79184365f, 1.87835479f, 0.60902315f, -0.94220877f, 0.79479855f, 
-0.25656971f, 0.08739131f, 0.53384244f, 1.22159266f, -0.39152125f, -1.46373534f, -0.02458516f, 1.62825716f, -1.26112676f, 0.19967082f, -0.71114451f, 0.27929229f, 0.65001321f, -0.11868202f, -0.55587751f, 0.78069001f, 0.57969242f, -0.60274386f, 0.31650013f, 0.90339553f, 0.09453616f, -0.37119162f, -1.00320566f, 0.33299938f, -0.48636708f, 0.26342997f, -0.91914523f, 0.28682709f, -1.24780893f, -1.59254742f, 0.97176319f, 0.14744301f, -0.53056234f, -1.73221612f, -0.67645556f, 0.98705006f, 0.79895812f, -2.04333115f, -0.60132772f, -0.91653955f, -0.28094748f, 0.47943443f, 0.38157779f, -0.67648011f, 1.09093642f, 1.66012859f, -0.29358891f, -1.26773024f, 0.36747769f, -1.10141146f, 0.82383633f, -0.89772314f, -0.47145563f, 0.63939518f, -0.64430422f, -0.48889321f, -0.37680882f, -1.06962025f, -1.28689516f, 1.28365147f, 0.61859220f, -0.84676331f, 1.38404000f, 1.21053445f, -0.14871351f, 1.06349385f, 1.45878971f, -0.47362664f, 1.40707004f, 1.25224137f, 0.87364739f, 0.92858213f, 0.00157326f, 1.45661485f, -0.27318576f, 0.15482858f, -1.07058907f, -0.06903186f, -0.74147576f, -1.64111829f, -0.67226541f, -1.13458407f, 1.28511488f, -0.41041154f, 2.09085560f, 0.45243183f, -0.67437285f, 0.84960121f, -1.49300814f, -0.42961186f, -2.35021853f, 0.57255560f, -0.73903763f, 1.37607956f, -2.44575167f, 1.25105727f, 1.38575912f, -1.16299784f, -0.13719854f, -1.11507034f, 0.35796806f, -0.64511567f, -0.87903833f, 0.32833642f, -0.87696886f, 0.02714214f, 0.30224666f, -0.69118696f, -1.23500824f, 0.76678628f, -3.20508122f, -0.24704689f, 0.49019828f, -1.20862615f, -0.03778638f, -0.07273687f, -0.11517122f, -1.75857520f, -1.64188445f, 1.21574795f, 0.57325113f, 1.14370298f, -1.07824504f, 1.70653832f, -0.03700557f, -0.47645858f, 0.11065386f, -1.03143036f, -2.18094873f, -0.94403434f, -0.09335683f, -0.44817665f, 1.39707148f, -1.21947956f, 0.56575936f, -0.69612634f, -1.12361753f, -0.17105591f, 1.15422392f, 0.02840637f, 0.09469353f, -0.52859986f, -2.08487725f, 1.28789508f, -0.03740775f, 0.61196613f, 1.23405397f, 
1.56595814f, -0.65800631f, 2.02985072f, -0.69446486f, -0.88443804f, -0.23448054f, -0.43628734f, -0.45888957f, -0.21943338f, 1.78258693f, 1.75214970f, 0.71804136f, 0.49782532f, 0.37886053f, -1.59176385f, -1.74758542f, -0.02820176f, 0.75398153f, 1.00119829f, 0.80881971f, -0.53365272f, -0.22720885f, 0.37476870f, 0.01005529f, -1.23421800f, -0.13431595f, -1.01843679f, 1.87386346f, -1.68539488f, -1.04942071f, -0.77322137f, 0.53964764f, 0.29278332f, -0.58299130f, -1.56022692f, -0.79441273f, 0.49289709f, 0.44112054f, 1.07305002f, 0.54899335f, 1.13781393f, 0.77809113f, 0.81795985f, 0.16576190f, 0.32552773f, -0.20250474f, 1.46543837f, 0.12731771f, 0.21013761f, -1.34241438f, 0.44267517f, 0.93246883f, 0.08808212f, 0.92653406f, -1.21083558f, 0.17247954f, -0.70557106f, 0.04630012f, 0.48834828f, 0.89634645f, 0.46683592f, -0.29553145f, 0.46363977f, -0.48971879f, -0.88603491f, -0.12333342f, 0.37073737f, 0.92061806f, 0.54675460f, -0.14716248f, 0.75578392f, -0.98173791f, -1.15983224f, -0.58713156f, 0.07950903f, -0.59016788f, 0.41622928f, -0.32474482f, 0.42086437f, 0.23061797f, 0.62596649f, -0.22615278f, -2.14721417f, 1.01685894f, -0.25976995f, 0.00739352f, -1.31597066f, 0.39005190f, -1.09549701f, 1.68375242f, 0.43331525f, -0.37124026f, 0.22255214f, 0.59654880f, -0.73840386f, -1.20048976f, 0.12226126f, 0.12997478f, 1.04826224f, 0.03894836f, -0.36289826f, 1.14466560f, -1.18198848f, -0.03713558f, 0.67677927f, -0.42329931f, -0.89409167f, -0.77874780f, 0.58438253f, -0.35176343f, -1.53329861f, -0.02995299f, -0.40145162f, -1.51052392f, 0.09194464f, -1.13275242f, -0.61983156f, -0.40004560f, -0.19893464f, 0.22134103f, -0.03903082f, 1.14894116f, -0.03476744f, 0.22520730f, -0.55851930f, 0.76650429f, -0.57863152f, -1.34161711f, -0.31498179f, -1.19411755f, 1.70044947f, -0.17428267f, -0.35983825f, -0.42613637f, 0.58165723f, -0.77866900f, -1.59727287f, -0.61723864f, 1.51078022f, 0.32971445f, -0.86441469f, 0.60552609f, 0.00208178f, -0.47096625f, -1.10479307f, -1.21652532f, -0.08211990f, 
-1.43739200f, -1.31684434f, 0.43312529f, -0.76822090f, 1.88128507f, -0.02179282f, 1.04971325f, -1.55004108f, 1.25337446f, 0.11203052f, -1.16048300f, 1.59467411f, -1.29469275f, 1.14019871f, 1.20021439f, 1.84098923f, 0.05004879f, 0.73529941f, 2.05272865f, -0.13080600f, -0.08436690f, -1.17919350f, -0.66256678f, -0.36727047f, 0.73840511f, 1.22293818f, -0.00206342f, -0.29839504f, -0.00618613f, 1.04213119f, 1.21176076f, -0.62886089f, -0.02589060f, 0.96009409f, -0.64478731f, -1.16516542f, 0.57528079f, 1.04294407f, -0.09774588f, 0.45935291f, 1.03263175f, 1.00633478f, -1.82209253f, -0.18035053f, -0.28302726f, -0.83813244f, 0.57593471f, -0.03807700f, 1.60498738f, 0.16530658f, -1.43083501f, 2.10824299f, 0.30279446f, -0.03961089f, -0.38900724f, 1.31272805f, -0.56575215f, 0.57970244f, -0.48305038f, 1.34114623f, 0.21859215f, 0.66399640f, -1.52087069f, -1.30717897f, 0.14394683f, 0.97648209f, -0.71372712f, -1.22574198f, -0.27702177f, 0.04041927f, 0.02442212f, 2.19617033f, -0.48566443f, 0.81463927f, 0.20383844f, 1.17562282f, -0.33829874f, -0.42141283f, -0.96415234f, -2.39141965f, -1.04285860f, -0.23004992f, 0.41186509f, 0.03811268f, 0.36818987f, -0.71099734f, -0.56749570f, 0.18486284f, -0.44530040f, 2.14008284f, -0.27467576f, 1.70690107f, -1.40462613f, 0.24697532f, -1.31629777f, -2.20674944f, -0.67868507f, -1.15767133f, -0.64391804f, -1.79037917f, 0.58749497f, -1.58303332f, -0.69021022f, 1.64376318f, -0.95393223f, 1.98415601f, -0.10991055f, 0.02474386f, 0.23683345f, -0.63420391f, -0.57991928f, 0.83028817f, -0.40033704f, 0.19212338f, 0.74640590f, 1.10264432f, -1.65286255f, 0.92683482f, -1.42252541f, -0.74605089f, 2.14535880f, 0.12971123f, -0.47971717f, 1.67546797f, 0.42268261f, 0.22648531f, -0.42369929f, 0.77403021f, -1.31818616f, -0.67143595f, -0.04311426f, 1.64128351f, 0.34776631f, -0.39353722f, -0.42765084f, 0.16170517f, -0.54488391f, -0.38428506f, 0.42097485f, -0.55982012f, -1.74543798f, 1.53704774f, 0.43562424f, -0.30395737f, 0.31846946f, 0.39205357f, 0.57386035f, -1.11912560f, 
-1.39164317f, -1.04337609f, 0.31629622f, 1.51927638f, 0.88745505f, -0.40445471f, 0.25783861f, 1.88646257f, 0.36509129f, -1.13266826f, -0.45394278f, -0.48400903f, -1.22332740f, 0.38626808f, -1.10049105f, 0.84138852f, 1.27863181f, 0.53942156f, -0.67743856f, -0.03896645f, 1.70393491f, 0.60997570f, 0.43368068f, -0.13338457f, -0.18920666f, -0.29583672f, -1.40738738f, 1.03876019f, 1.71253765f, 2.12821221f, -0.96092403f, 0.93841934f, -0.79030478f, 1.36427641f, -1.39196694f, 0.08514920f, 0.16223004f, 0.71259701f, 0.20150672f, 0.25068361f, -0.99952722f, 1.80129099f, -1.28586197f, -0.64957166f, -0.94813949f, -0.40161121f, 0.31977695f, 0.54932386f, -0.67757767f, 1.88086259f, 0.92337233f, -1.64887333f, 0.44333732f, -0.19468001f, 0.12977587f, 0.21171951f, 0.27679422f, 0.49134475f, -1.44429457f, 1.25617445f, 0.39978400f, 0.99869555f, -1.61617446f, 1.61177349f, 0.70243025f, -0.95748568f, -0.61795151f, -0.77302909f, 0.72967088f, 0.81964350f, -0.71813750f, 0.90140164f, -1.45950246f, -0.79972702f, 0.40875742f, 0.00152073f, -1.74491429f, 1.53776145f, 0.75769204f, -0.22075878f, -0.58385569f, 2.18884754f, 0.33597681f, -1.66265559f, 1.03805876f, -1.55245185f, -0.03582226f, -1.94542754f, -0.76081425f, -0.50471377f, 1.35763168f, -0.39631784f, -0.17134467f, -0.82220149f, -0.41021580f, -0.00940776f, -0.80176353f, -0.19816744f, 1.22061026f, -0.14486519f, -0.71727395f, -0.65721530f, 0.47020102f, -0.70403302f, -0.94795334f, 1.79884899f, 0.07779162f, -1.50615680f, 0.04140327f, -0.22001404f, 0.63735324f, 0.79237640f, -2.25412822f, -0.52519119f, -0.87280381f, -0.07100742f, -0.94734806f, -0.12286110f, -0.13623615f, -0.42595413f, 0.17547913f, -0.81707209f, 0.36855817f, -1.68186557f, 0.19312963f, -0.66249490f, -0.98283452f, -0.33314428f, 0.40918943f, 0.88268638f, -0.05390308f, -0.22440539f, -0.15879378f, -0.34859571f, -0.01013108f, -0.30005428f, -1.19408464f, 0.21789688f, -1.07769871f, 0.81475031f, -0.69555300f, 2.35201311f, -0.40362412f, 0.93497628f, 1.13343573f, 0.92343372f, 0.26987928f, 
0.46123627f, 0.22577702f, 1.26289701f, -0.45956740f, 0.55994868f, -0.58410591f, 0.13304594f, -0.25806463f, 0.49044946f, -0.82065403f, -3.06672239f, -0.27774641f, 0.68504512f, -0.21386372f, 1.11427057f, -0.73201770f, 0.51655543f, 1.77261138f, 0.72081727f, 0.11116749f, 0.16637769f, -0.74987584f, 0.66579849f, -0.75808716f, 0.20678560f, -0.67698354f, -0.82141948f, 0.61008269f, 0.66520184f, 0.44894725f, 0.73015076f, -1.52517414f, 0.11714164f, 1.90452611f, -1.30355322f, 0.12144456f, 1.18547559f, -0.07349755f, -2.28061509f, 0.83522540f, 0.78438890f, 2.19334102f, 0.90305614f, -0.59345531f, 0.77925014f, 1.32338643f, 0.14068902f, 1.19032264f, 0.20666829f, -0.76595837f, 0.74967057f, 2.86965609f, 0.55690205f, -1.72530472f, -0.83317834f, -0.85842621f, -0.29678273f, 1.80955839f, -0.70496303f, 1.19106734f, -0.92985237f, -1.00617313f, -0.56049556f, -0.29382578f, -2.04022193f, -1.95356870f, -0.42553005f, -0.33369407f, 1.02115977f, -1.45769477f, -0.67720300f, 0.53819913f, 1.57643425f, -0.47015440f, -1.47861958f, -0.00545934f, -0.97836047f, 0.42680529f, 1.56110144f, -1.49487829f, -0.65198445f, 0.22720462f, 1.83036661f, -0.47099793f, -0.09915133f, 0.14923312f, -1.16313052f, 0.67798084f, -1.63665557f, -0.38220280f, 0.01719763f, 0.30041245f, 0.43148938f, -0.44021657f, -1.25734651f, 0.02465564f, -1.00845659f, -0.28574651f, 0.01367745f, 0.77253437f, -0.99399441f, 0.61445391f, 0.18343423f, -0.50997210f, 0.41359940f, 0.77279282f, 0.83511519f, 0.27929801f, 0.70800692f, -0.20278299f, 1.57884383f, 0.22650529f, 0.43347472f, 0.74003208f, -0.71401161f, -0.69829476f, -1.56766701f, -0.99254119f, 1.27301061f, 2.73726511f, 0.66089469f, -1.95778012f, -1.24642098f, -0.63579029f, -1.63168180f, -0.66980726f, 0.81933254f, 0.61866677f, 1.40594471f, 0.05158535f, 0.00196500f, -0.24592508f, -0.50780547f, -0.83905292f, -0.10748957f, 0.04490763f, 0.27769178f, -0.23227681f, 0.82108080f, 0.03562285f, 0.95483875f, -1.49897683f, 0.67809856f, 0.35497451f, -0.44021592f, -1.67361462f, -0.88895375f, 1.44293678f, 
-0.85046643f, -0.46437624f, -1.87252641f, 0.26775804f, -0.24535774f, 0.73365933f, 0.52253938f, 0.27947086f, -0.58796054f, 0.59045380f, 1.93476331f, -0.46775359f, 0.25238225f, -1.26601815f, -0.13324316f, -0.71454948f, -0.21610366f, -1.49586582f, 1.04903507f, 0.22208478f, 0.25512528f, -0.46157327f, -0.41319233f, -0.63846964f, -0.25100923f, 0.81277549f, -0.26959971f, 0.88737756f, 1.24578953f, -0.91121447f, -1.05756927f, 0.44390878f, 0.16672316f, -1.22941923f, 0.89547867f, -1.50212002f, -1.69620168f, 0.53339505f, -0.23656729f, -1.69879091f, 0.01510374f, 0.08315694f, -0.73196459f, -1.60263407f, -1.07601058f, -0.76389569f, -1.65307498f, -0.61484390f, -0.43546933f, 0.71318507f, -0.16273083f, 0.64122051f, -0.15406294f, 1.17673671f, -0.91240519f, 0.71091145f, 2.40497613f, 1.26343656f, 0.71469337f, 0.20705548f, 0.81776261f, 0.36253929f, -1.92106628f, -0.09300470f, -0.36648872f, 1.27732766f, -0.39180157f, -0.61186749f, -1.03455031f, -0.25079829f, -0.61479062f, -1.07094336f, 0.82218504f, 0.89934880f, 0.41308978f, -0.59968555f, 0.37682834f, -1.77388155f, 0.00294951f, -0.66145372f, -0.50789726f, -0.85123241f, -0.89909405f, -1.89454281f, -0.56692821f, 1.52272677f, -0.11961794f, 0.27843913f, -0.60582250f, 1.01871169f, -0.36098275f, -0.12242325f, -0.67375034f, -0.11204147f, -2.62773919f, -0.95901299f, 0.14040214f, 1.32364666f, -1.35099924f, -0.11077739f, -0.79319423f, 0.75949597f, -0.25485823f, -0.90959758f, -0.42373934f, -1.29850340f, 0.85699379f, -1.11882365f, 0.63470817f, 0.49696380f, -0.07983235f, -0.23903450f, -0.22618714f, -0.12117998f, -0.09442677f, 1.55589819f, -0.11996678f, -1.72700179f, 0.54683149f, -0.40804827f, -0.50099218f, 0.34596699f, -1.81841791f, 0.06385052f, 0.84428120f, 0.69901514f, 1.94559097f, 0.43251973f, 0.16794942f, 1.82829034f, 1.70959795f, 0.36130908f, -0.94608402f, -0.53498030f, 0.47781768f, -0.24203247f, 1.25065851f, 0.51788396f, -2.09381890f, 0.72973937f, 0.03281829f, 0.58632666f, 1.85737121f, -0.49569523f, 0.45921183f, 1.87173629f, 0.22803484f, 
1.66433418f, -1.05872321f, -1.13663685f, 0.12397861f, -0.65112090f, 0.98152941f, 0.83739656f, -0.18783289f, 1.84249437f, -0.90706986f, -0.80824369f, -1.23854923f, -0.86488134f, -1.02627063f, 0.10976455f, -0.61403006f, 1.27554715f, 0.14653525f, -0.03953953f, -0.08512071f, -1.30043304f, -0.02566035f, 0.12054887f, 0.00282162f, 0.48921332f, -1.74398839f, 1.44554436f, -1.35854721f, 0.69256759f, 0.34101671f, 2.50045252f, 0.49121150f, -0.27115449f, 0.93974596f, 0.26258010f, 0.27151433f, -0.87214381f, -0.92580765f, -1.03269923f, 0.20615758f, -0.37822601f, 0.58983004f, 0.16426525f, 0.68218285f, 1.98158526f, 0.47492698f, 0.54224718f, 1.28722692f, -1.76915324f, -1.11240053f, 0.77428484f, 0.27184650f, 2.22473478f, -0.05574624f, 0.39976570f, -0.43911108f, 0.52805597f, 0.17340177f, 1.36057591f, -0.35004014f, 1.72787797f, 0.68357420f, 1.25532615f, -0.56752264f, 0.51840127f, -0.21237844f, -0.58821255f, -0.85278064f, 1.90179110f, -0.67447448f, -0.36831430f, -0.22930753f, 0.98231596f, -0.07011599f, -0.08560387f, 0.05998110f, -0.02481356f, -0.57335132f, -0.44288307f, -0.24468307f, 0.53321087f, 1.19609559f, 0.10664973f, 0.24379487f, 0.93687552f, 0.93615580f, 1.74319768f, -0.68310338f, 1.32163060f, 0.61918712f, -0.76501870f, -0.54549301f, 1.74077415f, -0.69977754f, -0.66880983f, -1.15981388f, 0.81571609f, 0.53788543f, 0.47898352f, -0.02484704f, -1.64646924f, -0.69822907f, 0.27020717f, 0.05027051f, 1.75149667f, 0.01548872f, 0.32615909f, 2.55151844f, -1.29172051f, -0.36133784f, 0.98637396f, 0.14009331f, -0.50038946f, -0.92230296f, 0.17307127f, 1.05361068f, -1.46784890f, 2.38960409f, 1.19413340f, -1.33349669f, 1.59141159f, -0.71811068f, 1.22429430f, 1.26947939f, 1.08177102f, -1.18138707f, -0.72775704f, 0.17282635f, -0.40554270f, -0.40341887f, 0.46564049f, -1.02069795f, -0.07653128f, -0.13979210f, -0.31195050f, -1.72042310f, 1.37131393f, 0.63849634f, 0.75561279f, 1.81152904f, 0.26686314f, 1.32796574f, 0.56100166f, 0.70058894f, -0.88962644f, -0.04360984f, -0.88249093f, 0.24311203f, 
0.50410056f, -2.22567797f, 0.94520348f, -2.12467694f, 0.47282359f, -0.71379906f, -0.09857135f, 0.62374717f, 1.37182784f, 0.73380554f, 0.59745449f, 2.80427694f, 0.67253572f, 1.65335357f, 1.69891667f, 1.34585941f, -0.79989213f, 1.44980943f, -0.52013642f, -0.46971673f, -1.50070012f, -0.25687039f, -0.56916732f, 0.71065760f, -1.31996286f, 0.96031237f, 0.13929774f, 1.49679291f, -0.05966444f, -0.58674580f, -0.08278833f, -0.93390942f, 0.42415768f, -1.77889526f, 0.75336021f, -0.72699982f, -0.82880586f, 0.63955617f, 0.42771208f, -0.42366457f, -0.91581815f, 0.94750947f, 0.43123913f, -0.99053741f, 0.70470595f, -1.16662264f, 1.14847183f, -0.83885664f, 0.46714026f, -2.27748466f, -1.23656678f, 0.14695056f, -0.33159894f, -0.52553117f, -0.04391259f, -0.29630372f, 0.25949728f, 0.96991086f, -0.37714824f, -0.28251833f, 0.16106486f, 1.38844633f, -0.18713553f, -1.30708838f, 0.48490265f, 0.29553881f, -0.45505449f, 0.83341682f, 0.87346369f, -0.63516861f, 0.66063565f, 0.93892503f, -2.73996735f, -0.81515318f, -0.91458052f, 0.00978268f, 0.43472794f, -0.08090764f, 1.37249672f, 0.76722521f, -1.19154143f, 0.22046764f, 0.34916410f, 0.51383299f, -0.56379753f, -2.49949312f, -0.74207872f, -0.68400806f, -0.09663232f, -0.07199454f, -1.05562651f, -0.75028551f, -0.87253797f, 0.69039482f, 0.45923674f, -1.27515161f, -0.04555376f, -1.41501272f, -0.83773375f, -0.74807298f, 1.36646152f, 0.06317432f, -1.32559633f, 1.89092779f, 1.24883330f, -1.03608561f, 1.08677161f, -0.99629849f, -0.69947034f, -0.85716367f, -0.07947286f, -0.25485426f, -0.19732477f, 1.64581251f, 1.04618108f, 1.87186897f, -0.18198362f, -0.83807969f, 0.70462501f, -3.18930101f, 0.74610996f, -0.60935193f, -0.49383929f, -2.88986492f, 0.51707613f, 1.04620326f, 1.09837818f, -1.19840038f, -0.10391295f, -0.20789115f, -1.51052022f, -0.31087330f, 0.22411564f, -1.30506921f, -1.52000105f, -1.51593041f, 1.04321992f, 0.97611690f, 0.90424490f, 1.83324766f, -0.08682299f, 0.47035542f, 1.70865905f, -0.31108001f, 0.04115159f, -1.36352801f, -0.90797836f, 
0.32128647f, 0.66191489f, 0.08681208f, 0.14993365f, 0.47110486f, -0.31522670f, -0.38906571f, -0.08876022f, -0.13106902f, 2.25685239f, -0.62211353f, -1.68553007f, -0.23707703f, 0.69236159f, -0.46686995f, -0.27520603f, 0.26619941f, 1.48525345f, 1.61278927f, 0.49452963f, 1.20846486f, -1.11853909f, -0.30010033f, -0.75471467f, -1.69959772f, -0.52042168f, -0.43881389f, -1.45240712f, 1.02122891f, 1.73639011f, -0.03813924f, -0.22239220f, 0.15797073f, -0.64418089f, -0.60228932f, -0.83248150f, -0.02042520f, 0.38137484f, 0.86056453f, 0.06410559f, -0.62785137f, -0.49916875f, -2.53796315f, -0.79168582f, -0.69197005f, -0.77175534f, -0.28669405f, -0.79764080f, 0.97218460f, -0.10351621f, -0.52759898f, 1.02840185f, 1.16363287f, 0.08351815f, -0.61088538f, 0.59944046f, 1.54409397f, -1.39842033f, 0.27917057f, -0.27146137f, 1.46310735f, 0.03626106f, 0.15038440f, -0.07894899f, -1.42527366f, 1.69641745f, 1.48384345f, -0.43328866f, -0.54252565f, -0.94416499f, 1.54436302f, -0.81367069f, -1.67925239f, -0.17525831f, 0.27891046f, -0.69066733f, 0.89911050f, 0.11606655f, 0.67450327f, 0.41538724f, 0.90886223f, 1.19786549f, 0.85810721f, 1.32862210f, -0.83469814f, -1.09682298f, 0.88092703f, -0.97478902f, -0.11664717f, -0.07929394f, -0.69581884f, -0.16928329f, -0.70731819f, -0.40485084f, -0.28954300f, 0.52882415f, 0.38769314f, -1.38704026f, 1.15099049f, -0.43566978f, 0.34459323f, 0.49520254f, 1.11130333f, 0.28783718f, -0.53783375f, -1.63577271f, 1.02222812f, 0.86302060f, 0.48346213f, 0.46627176f, -1.30133855f, -1.48477137f, 0.31219670f, -1.21498191f, 0.89838904f, 0.87186617f, -0.39968935f, 0.34930915f, -0.32909471f, -1.39364409f, 2.13006306f, 0.33270469f, 0.00215986f, 0.97776711f, 0.24908836f, 1.56164885f, 0.45157790f, -1.55970144f, 0.27677536f, 0.07662498f, -0.08262251f, -0.17658773f, 0.65820259f, 2.01052690f, -1.71946216f, 0.84686053f, -1.23594892f, 1.40792072f, -1.47772563f, -0.36132276f, -0.50405115f, 0.09009213f, 0.81659186f, 1.85574234f, -0.64974433f, 0.63352364f, 1.01766217f, -1.54804432f, 
-0.42570522f, -0.24763709f, 0.72822112f, -0.93733686f, 0.68087620f, -1.40644944f, 0.48672482f, 0.09725539f, -0.64416331f, -0.95747960f, 0.36771363f, 0.39155054f, -0.71790671f, -2.17222738f, -0.08655047f, -0.97842115f, -0.22991380f, 0.52029115f, -1.42072022f, 0.29576331f, 0.32391560f, -1.00823236f, 1.67909145f, 1.16841447f, -0.32307062f, 0.15756166f, -0.97590631f, -0.39429301f, -0.03583352f, 0.17554663f, 0.57961231f, -0.46873134f, -0.23343173f, -0.85060924f, 1.71745574f, -0.04658702f, 0.63088381f, -0.67581934f, -1.53171062f, -1.58800113f, -1.17987096f, -1.16737640f, -0.87544650f, -1.17138922f, 0.38979119f, -2.39369726f, -1.34747124f, 0.58450359f, 0.87791806f, -0.04459394f, 0.97995293f, -0.10354915f, 0.65324986f, -0.17833626f, -0.85849386f, -0.42063358f, 0.19708554f, 0.10255250f, -0.59539181f, 0.86194044f, 1.68610668f, 0.55275291f, -0.43127069f, -0.04218780f, -0.08466262f, 0.31236625f, -0.92824298f, -0.09879152f, 0.32358822f, 1.04045570f, 0.35617545f, 0.09059231f, 1.19069445f, 1.96978688f, 0.63561743f, 0.15030998f, -0.29879019f, 0.22774190f, -1.01608860f, 1.03605175f, 0.47804731f, -0.30450734f, -0.61382371f, 0.45390254f, -1.93547988f, 2.01267338f, 0.52447683f, 0.18379784f, 1.11913633f, -1.24273467f, 0.15803322f, 1.72184098f, -0.79349059f, 0.10258614f, -1.53445125f, 0.02630571f, 0.81649125f, 0.91089755f, -1.12968338f, 1.04016411f, 0.28999722f, 0.74863863f, -0.61388236f, 0.01665530f, 1.43592548f, 0.68138391f, 0.11963340f, -1.26123953f, 1.36340797f, 0.25696915f, -0.58877039f, 1.42209792f, 0.55563360f, -1.33329606f, 1.84695840f, 0.88433737f, 1.04359078f, 0.18906727f, -0.03448994f, 1.17944050f, 0.86783957f, 0.44934425f, -0.77892244f, -1.76232874f, -1.01689589f, 0.78943914f, 0.92141974f, -1.00187087f, -0.13809921f, -0.90222073f, 1.10094714f, -0.13657950f, -0.44349849f, -1.61441302f, 1.05724919f, 1.50337231f, -0.05785890f, -0.76958144f, -0.51498759f, 0.69227600f, -0.37975949f, 1.31949317f, 0.82049531f, 0.32868597f, -0.31557772f, -0.75534385f, 1.27303052f, 0.43453619f, 
0.11296938f, 1.18182182f, 2.23387384f, -0.86412978f, -0.01599468f, -0.70869064f, -0.09221385f, -1.23729551f, 0.79490280f, 0.03522846f, -0.95069039f, -1.73461652f, 0.72329187f, 1.40385795f, -0.11585230f, -0.78033113f, 0.07491048f, -1.12873089f, 0.18476245f, 0.57568848f, -0.28792691f, 1.35411644f, -0.76956165f, 0.29571572f, 1.03178787f, -0.38780826f, 0.31680650f, 0.69368076f, -1.23856580f, -0.49848995f, 0.14766994f, 1.02625990f, 3.03858209f, -0.51030380f, 0.96796870f, 1.35078156f, -1.07729447f, 0.84322494f, 0.54886484f, 1.31453705f, -0.45792100f, 0.31196272f, -0.15701357f, 0.83586836f, -0.74952888f, -1.17432022f, -0.31002575f, -1.02149463f, -0.36117774f, -1.22079086f, 0.03532525f, 0.00555908f, -0.45891216f, 0.29636297f, -0.68272704f, 0.41257843f, 0.37988129f, 0.01747893f, 0.82739186f, 1.52292180f, -0.79456621f, 2.20275712f, 2.13212132f, -0.81393015f, -1.15712392f, 0.22488308f, 0.62776327f, -0.85444915f, 0.44017896f, 0.05863331f, -0.83198178f, 0.93063420f, -0.16121253f, 0.12382501f, -0.37826315f, 0.93118382f, 0.19507533f, -0.58595538f, 1.46994352f, 0.13170272f, -0.70031989f, -0.12820166f, 0.30487457f, 0.84148771f, -0.68807501f, 0.21187615f, -0.67030680f, -1.79136002f, 0.70810199f, -1.20959783f, -0.08468831f, -0.06317700f, 1.35527098f, -0.47018668f, -0.91693246f, 0.14818805f, -0.05405350f, 1.16875637f, -0.17363262f, -1.61833882f, -0.32934523f, -0.38346377f, -0.62702698f, 0.34135151f, 0.48015586f, -0.65263331f, -0.04689486f, 0.01156854f, 0.37580970f, -0.16174591f, 0.59627324f, 0.24351901f, -0.87983090f, 1.57049024f, 1.25836349f, -0.41464049f, -0.62279183f, 0.09693756f, -0.23850618f, -0.49007827f, 0.22298151f, 0.10914832f, -0.35192192f, -1.27221346f, 1.10203624f, -0.86399704f, -0.47319838f, -0.77105570f, -1.68624854f, 0.81198281f, 0.82534081f, 0.75654501f, 1.47631240f, -0.61000234f, -0.58933264f, 0.54822850f, -1.22829592f, 0.11107657f, 0.56449169f, 1.50693524f, -0.59280968f, -0.64286685f, -0.20120731f, 0.27184448f, 1.55500400f, -0.48919386f, 1.04044867f, -0.87048137f, 
-0.40569979f, 0.21908638f, -0.51829034f, -1.48748124f, 0.02990401f, 1.83462536f, 0.29885170f, 1.32370698f, -1.30129600f, 2.43271399f, 0.22967771f, -1.13014007f, 0.95529765f, -0.83325785f, 0.43633386f, 0.85774118f, 0.78160155f, 0.58583075f, 1.18906367f, -1.54354560f, -0.68320692f, 0.01900371f, -0.79777133f, 0.12851712f, 1.10176420f, 0.79418170f, -1.41154039f, 0.36929929f, 1.12176800f, 1.23849642f, -0.89377707f, 1.01390159f, -0.50889206f, -1.12554002f, 0.17932732f, 0.48949540f, -0.54235244f, -0.28146735f, -1.39125514f, 0.13309635f, -1.12864995f, -1.29901242f, -0.04266220f, -1.98028529f, -1.34869373f, 0.00038156f, -0.92473024f, 1.48010647f, -0.02754467f, -0.26030368f, 0.93083733f, 0.27946711f, 0.64052200f, -0.04220961f, 1.25002527f, -1.07923257f, 0.19048618f, 0.08900311f, -0.40813437f, -0.73068553f, 0.52122378f, 0.68990833f, -0.38749605f, -1.09269309f, -1.63480806f, 1.01789618f, -0.61596102f, 0.81049860f, 1.30838764f, -1.49213874f, -0.77916288f, -0.72660202f, -0.92013240f, -1.61726642f, -0.11527207f, 0.35143322f, -1.11646879f, -1.45525432f, -0.82892823f, 0.15512508f, 1.01891017f, 1.40162635f, 1.02494884f, 0.33882582f, -0.78747398f, -0.26009330f, -0.38519114f, 0.79247451f, 0.02065756f, -0.48030257f, 1.01167107f, -1.74057114f, -0.84549171f, -0.15337363f, -1.92544484f, 1.01270044f, 0.00762185f, -0.16405612f, 1.61778915f, 0.93316060f, -0.68960994f, -1.13214970f, -0.94695878f, -0.28418848f, 0.17102109f, -0.08787476f, -1.83799696f, -0.13761258f, -0.18652774f, 1.46456254f, 0.34169790f, -0.40697145f, 1.49663997f, -0.99555492f, -0.67775637f, -0.51951116f, 1.35157657f, -0.27099034f, -0.46987835f, 2.28101230f, 0.59104478f, 0.75010139f, 1.01472175f, 0.25741309f, -0.56074983f, 1.12267506f, 0.35336846f, 0.61733276f, -1.63976014f, -0.17700450f, -0.25093642f, -0.75599891f, 2.10956192f, 0.95155340f, 0.72049862f, 0.50492924f, 0.62067389f, 2.08688402f, -0.73604703f, 0.63383341f, -0.53528428f, -2.11538506f, -0.98173052f, 0.59560484f, -0.26205051f, -0.91948050f, 0.00593397f, -0.11734286f, 
-1.41261208f, -0.83611172f, -0.27682739f, -0.20619918f, -0.36557615f, 0.77194935f, 1.67695415f, -1.39265156f, 0.04892010f, -0.37773246f, 0.16124558f, -0.18348448f, -1.38248885f, 0.58459854f, 0.65064198f, 1.11349559f, 0.36708066f, -0.15471332f, 0.14208725f, -2.06860566f, 0.29629150f, 0.93084633f, -0.47215626f, 0.60208917f, 0.95415461f, 1.03390312f, -0.03639749f, -0.23988228f, 1.27037442f, 0.95133096f, 0.33187470f, -0.34527761f, 0.22134073f, 1.01799667f, -0.81475645f, -1.18869019f, 0.23314142f, 0.25180560f, -1.23762786f, 1.25283313f, 0.16980635f, 0.40740708f, 0.59256923f, 0.16274920f, -0.69713289f, -0.16444311f, -2.41602516f, 0.37952334f, -0.05604568f, -0.23772651f, 0.20581599f, -0.54303211f, 1.71877348f, 0.83602583f, -0.32586128f, 0.73609394f, -1.73640239f, 0.07249248f, 0.31248692f, 1.77627432f, 0.97660398f, -0.42095289f, -0.18750280f, -0.84246057f, 0.29762223f, 1.87054563f, -1.46980762f, -0.45306337f, 1.52366042f, 1.39061129f, -0.04980387f, -0.55382830f, -0.96987218f, -0.06910808f, -0.41276473f, -0.83891344f, -0.92597574f, 0.60252470f, 0.21938549f, -0.04451685f, -1.00330937f, -0.36955237f, -1.52876902f, 0.27296364f, -1.96721256f, 0.05291027f, -0.91540521f, 0.48990685f, -1.99560380f, -0.68551093f, -0.14532298f, -1.56881595f, -0.08319287f, 0.31003201f, -1.42829597f, -0.61810297f, -0.03581250f, 0.77747720f, 1.25297558f, -1.36239243f, -1.13274276f, -0.35045877f, -2.34157228f, 0.04515179f, -0.83044821f, 1.81353962f, -1.36855912f, 0.39704823f, 0.16665934f, -0.16654585f, 1.17806077f, 1.00086153f, -1.25474250f, -1.46876431f, 1.18021631f, -0.32257929f, 2.12062597f, 0.86819613f, -1.18048275f, -1.69747460f, -0.74092305f, 0.05086798f, 1.15339577f, 1.32972670f, 0.27247882f, 0.98499072f, 2.35597157f, 0.30179837f, -0.66633248f, 0.13794266f, -0.22753908f, -0.22868259f, -1.81792033f, 0.50151759f, -0.79408127f, -1.05343878f, 0.45727381f, 0.84800923f, -1.73605800f, -0.02032863f, 1.82778001f, 1.41025102f, -0.81715560f, 0.25888795f, -0.25075480f, 0.66256499f, 0.11993053f, 1.81336939f, 
-0.06345166f, -1.49658346f, 0.07531686f, 0.96972889f, 0.87405980f, 0.75830793f, -0.13497087f, -2.45855975f, -0.65984958f, 0.93919373f, -0.97305542f, 0.73477978f, 1.04337513f, -1.22712576f, -0.46385625f, -1.20876372f, -0.82760453f, 0.01455977f, -1.05089867f, -0.02801843f, 0.60899758f, -0.82052249f, -1.48932517f, -0.98073828f, -0.19311285f, -0.25602359f, 0.50351876f, -1.24557400f, -0.82138073f, -1.45966852f, 0.44991320f, -0.75550151f, -0.98550314f, -1.21418869f, -1.15771639f, -1.72192061f, -0.39616469f, -0.55566746f, -1.31880891f, -0.08843257f, 1.00422776f, 0.35846478f, 0.46060917f, 0.77326930f, 1.60129988f, -1.85124147f, -0.30582917f, 1.30227256f, 1.81890345f, -0.44084981f, 0.25315762f, 0.70259613f, -0.94882858f, 1.97040296f, 0.71473581f, -0.68193883f, -0.36290962f, 1.16348684f, 0.15418798f, 1.07806778f, 0.40554729f, 0.10280909f, -1.06474805f, 0.64398485f, -0.63568884f, -0.06108581f, -1.03290677f, 1.02834034f, 1.15284693f, 0.14046004f, 1.86630619f, 0.46804786f, -0.68397558f, 1.60733378f, -1.64890087f, -1.03819239f, -1.19212389f, -0.78382361f, 0.03925850f, 1.52259934f, 0.09540676f, -0.21220762f, 0.55955195f, -0.39845437f, -2.14541650f, 0.49337825f, -0.68574250f, 0.74040270f, 0.50783634f, -1.60461199f, -1.26806450f, -0.12652303f, -0.83992827f, -0.15524681f, 0.40098447f, 0.23392735f, -0.23262636f, 0.06525709f, -0.35994548f, -1.08432877f, -0.21395946f, -0.78357452f, -0.57157278f, 0.71407390f, 0.86596155f, -1.13723528f, 0.13460183f, -1.20881450f, 0.71018457f, 0.68943661f, -0.70428050f, 0.64600736f, 0.01990297f, -0.10575775f, -0.80263519f, 0.10618331f, 0.08865548f, 1.51651669f, 0.60851854f, 1.15161908f, 1.04919207f, 1.18359745f, -0.04352076f, -0.83643389f, -0.07922365f, 0.10597949f, -1.34984851f, -1.91319740f, 0.71585363f, -2.10845160f, 0.64385056f, -0.54551518f, -1.02039802f, -1.62510490f, 1.65401149f, -0.42711899f, 0.07970079f, -0.21404363f, 0.30498922f, 1.07942021f, 0.63995659f, -1.82114816f, 0.56396323f, 1.07084870f, -2.00350380f, 0.53339815f, 0.18500003f, 
1.15034151f, -0.21436051f, -0.99986565f, -0.58812016f, -0.07247020f, 0.78910017f, 0.48839527f, 0.98795873f, 0.10357288f, -0.05604928f, 0.38977858f, 0.73745090f, 1.40838420f, 0.25967824f, 0.23588051f, -0.03451392f, 1.04897523f, -1.77121758f, 2.35625434f, -0.67086869f, -0.84005541f, -0.85940343f, -1.04449213f, -0.65917015f, -0.78713167f, -0.95910054f, 0.38597879f, -0.31879017f, -0.86260867f, -1.08593106f, 0.02802678f, 0.99484950f, -0.55113328f, 2.60936737f, -0.03388772f, -0.47583574f, -0.14021793f, 0.99019170f, -1.22431207f, 0.78734446f, -1.77037835f, 0.15018673f, 0.36423206f, 1.36447549f, -1.61007094f, 0.51875496f, -1.60788095f, -1.73557448f, -0.41414359f, -0.93710536f, 0.38715765f, 0.04243837f, -1.59682858f, -1.10728157f, 1.88292623f, -1.01428258f, 0.01074958f, -1.88169158f, -0.31616244f, 0.45334938f, 1.12449574f, -1.16699445f, -1.59505820f, 0.04126552f, -0.89016622f, 0.45838884f, 0.71463561f, 0.14563711f, 0.30694655f, 0.67193079f, 0.61429602f, 1.00201404f, -0.49295208f, 0.05997690f, 0.99491668f, -0.73801446f, -1.17185295f, 0.94778723f, 0.36106884f, -0.43561545f, 0.04102699f, 0.52626407f, 0.08442099f, -1.57626402f, 1.56855237f, -1.65396678f, 1.74014664f, -0.38219589f, 0.39305371f, -0.31705827f, -1.15742850f, 0.11669596f, 0.54043210f, -0.52270615f, -0.13375773f, 0.68094701f, -1.84134769f, -1.49383473f, 0.14632171f, -0.54607725f, -1.20867658f, -1.28439069f, -1.81734920f, 1.54257309f, 0.78347659f, -0.24049839f, 1.69973648f, 0.99825776f, 0.99971974f, -0.26055810f, 0.34143049f, -0.44862366f, 0.11253342f, -0.60932243f, 0.70383030f, -1.87318194f, 0.21953633f, 0.82791799f, 1.64545465f, -0.42693698f, -0.64897031f, -0.97996652f, -1.06616282f, 0.52939081f, -0.12541170f, -0.57480675f, 0.73600835f, 0.35711968f, -0.03528263f, 0.79997194f, 0.55742902f, -0.28909785f, 0.64331138f, -1.79893720f, 1.01572442f, 0.27111965f, -0.51778597f, 0.12906317f, 0.76148927f, 1.51315522f, 0.41101140f, 0.38008851f, 0.66759896f, -0.13804778f, 0.64854795f, 1.73474562f, 0.75999504f, -0.73411214f, 
-0.05406699f, 1.35664344f, -0.25298578f, -0.12696666f, -0.42628938f, 0.61129904f, 1.55259824f, -0.05820796f, -0.38598019f, -0.87325627f, -0.55066222f, -1.24557889f, -0.26509118f, -0.32103062f, 1.14031804f, -0.75985742f, 0.70659167f, -1.15016067f, 1.24906838f, 0.90396994f, -0.16241251f, 0.43682271f, -1.42695689f, 0.47134697f, -1.66143429f, 0.08698819f, -1.00775325f, -2.24129725f, -1.04226267f, -0.98537570f, -0.89938259f, -1.80710697f, -1.22866321f, 0.78125423f, 1.55150509f, 0.46235040f, 0.18444096f, 0.19313288f, -2.20686269f, -0.40341458f, 0.50321484f, 0.47339424f, -0.81383848f, -0.21972439f, 0.66612029f, 0.60239881f, 1.20443010f, 0.70015103f, 0.30632916f, 0.01489905f, 0.68129027f, -0.89645082f, -2.68969011f, -0.96684915f, 1.66421318f, 0.74333072f, -0.78321886f, 1.60063362f, -1.27524030f, -1.95856726f, 0.47504124f, 0.15398432f, -0.20796098f, -0.13449343f, 0.93458968f, 1.60390890f, 0.21798505f, -0.27035928f, -1.23248971f, -1.25361061f, 1.34666133f, 1.07233441f, 0.88799530f, -1.23687923f, -0.40781614f, -0.11916534f, -0.88050151f, -0.66422415f, -2.61471510f, 0.78276747f, 2.42323995f, -1.70715427f, 0.71550035f, -0.60298312f, 0.70491880f, 0.46175584f, 0.80827898f, -0.45108104f, -0.98219043f, -1.72823501f, 1.73190725f, 0.53906441f, -1.50445580f, -0.59250867f, -0.07239901f, 0.44743437f, -0.13740127f, 1.69935930f, -1.00480616f, -0.58191377f, 0.39853972f, -0.60960841f, -0.45473522f, -0.76396072f, -0.31872150f, 1.74509728f, -0.59950751f, 0.89810580f, -0.81400329f, 1.14280319f, 1.11165059f, -1.31295311f, -1.60784578f, -0.87506992f, -1.13461006f, -2.09486437f, -0.16449419f, -0.37728927f, 0.47595578f, -0.55342919f, -0.17574213f, 2.21499181f, 1.14331865f, -0.14938518f, 0.18935619f, -0.33802557f, 0.52538890f, 0.82673949f, 1.16562462f, 1.24713838f, 0.98890215f, -0.64991701f, 1.49886703f, 1.97769642f, 0.08059916f, -1.60925281f, -1.23822486f, -1.40829837f, 0.51331180f, -0.29928651f, -1.04348791f, -0.39911583f, 0.69380492f, 1.54516888f, 1.22791195f, 2.25008130f, 1.33348894f, 
-0.21775827f, -0.71937007f, 0.54982573f, 1.70691478f, 0.32459491f, -0.57187974f, -0.21614684f, 1.08274269f, 0.41384646f, 0.24497485f, -1.43703413f, 0.89616930f, 0.82032162f, -0.24598582f, 0.84271127f, -0.81894702f, -0.01828136f, 1.70397091f, 0.39505738f, -0.51221430f, -0.87979966f, 0.10795479f, 0.45194778f, -0.76008922f, 1.23394477f, -0.56798172f, 1.06459570f, -0.44333413f, -2.40399075f, -0.37267187f, 1.42946172f, 0.95734519f, 1.86127949f, -0.15217264f, 1.68742633f, 1.97638428f, -0.44211119f, -0.98393327f, -0.54173928f, -1.72017395f, 0.74697793f, -1.77827263f, -1.92299354f, -0.17189410f, -0.48633271f, -2.21230388f, -0.45906609f, -0.53493047f, 0.37253976f, -0.56951141f, 0.07728028f, 0.03530006f, -1.18123293f, 1.94158125f, -1.55930352f, 0.69334733f, -1.95163214f, -0.95800400f, -0.01804711f, -0.56747472f, -0.99099451f, -1.52853060f, -0.98279524f, -1.67307866f, 0.96121490f, 0.35654056f, 1.74034202f, -1.44633865f, -0.27781928f, 1.79457986f, -0.41029963f, -0.76871634f, 0.36555341f, -0.77664107f, 0.19535238f, -0.76185411f, -0.19828433f, -0.88820636f, 0.63885397f, 0.11346363f, -2.50265074f, 0.16319332f, -1.01288569f, 1.86605489f, 0.89761645f, 1.11795115f, -0.00714116f, -0.89034635f, -0.76447034f, -0.18822117f, -0.48340848f, -0.99788517f, 1.02172959f, -0.39395007f, 0.72566581f, -0.81438208f, -0.71715081f, 0.96243578f, -1.36424279f, -1.13870537f, 1.17602491f, 0.16320205f, 0.71959788f, 1.66669416f, 0.55690295f, -0.28912008f, -1.19219172f, 0.23308393f, -0.37963116f, 0.45347008f, -0.42606446f, 1.30938649f, 1.25128853f, 0.57649273f, 0.34440875f, -0.23893952f, -1.06604803f, 0.31336102f, 0.75727910f, 0.46772480f, -0.37650385f, -0.06036821f, 1.03686309f, 0.46158856f, -1.81028461f, 1.43393028f, 0.85494965f, -2.34685564f, -0.17571987f, -0.45592231f, -1.31190526f, 1.73194158f, -0.11856517f, 0.07041293f, 0.25689471f, -0.56000596f, 2.06649089f, 0.38954756f, 1.36627376f, 0.13905638f, 0.77370811f, 0.43944249f, -0.08798827f, 0.07245751f, -1.30234015f, 0.29710820f, 0.74389762f, 0.11971968f, 
-0.07381748f, 1.32652700f, 1.34079397f}); + + auto input2 = NDArrayFactory::create('c', {3, 4, 4, 5}, {0.98114507f, 0.96400015f, 0.58669623f, 0.60073098f, 0.75425418f, 0.44258752f, 0.76373084f, 0.96593234f, 0.34067846f, 0.57962620f, 0.77517051f, 0.97472977f, 0.79237527f, 0.68690428f, 0.21719366f, 0.79959206f, 0.84814187f, 0.22496814f, 0.08646965f, 0.31110474f, 0.79813162f, 0.19661444f, 0.57760099f, 0.72138960f, 0.15244268f, 0.87687051f, 0.11130344f, 0.01087698f, 0.34817841f, 0.54992017f, 0.23443850f, 0.31725614f, 0.59755220f, 0.20364695f, 0.00531392f, 0.23403114f, 0.07442912f, 0.83707647f, 0.89291743f, 0.09044587f, 0.69041462f, 0.29904183f, 0.61904680f, 0.85306847f, 0.34467042f, 0.95839152f, 0.54517124f, 0.29640937f, 0.94855959f, 0.95970016f, 0.94045145f, 0.95510301f, 0.34666505f, 0.34717010f, 0.69245678f, 0.71669175f, 0.59043738f, 0.64924132f, 0.06033522f, 0.60185199f, 0.04690073f, 0.59241154f, 0.40229547f, 0.23002481f, 0.45161195f, 0.73743778f, 0.93209113f, 0.37294358f, 0.50177744f, 0.15072501f, 0.26146917f, 0.05252146f, 0.04758931f, 0.76448288f, 0.85149045f, 0.08840467f, 0.07692576f, 0.33180160f, 0.27241259f, 0.74834620f, 0.56453640f, 0.23057286f, 0.68429752f, 0.11961551f, 0.39045977f, 0.44356094f, 0.77018807f, 0.07984410f, 0.47926806f, 0.26165759f, 0.18606064f, 0.89972877f, 0.17962874f, 0.47273120f, 0.64641705f, 0.61890443f, 0.58730015f, 0.25937832f, 0.35231561f, 0.10243882f, 0.17459193f, 0.95906995f, 0.09227025f, 0.30003223f, 0.41601210f, 0.38269713f, 0.84799751f, 0.59295173f, 0.76277990f, 0.68910424f, 0.37672606f, 0.40675461f, 0.94346058f, 0.91438505f, 0.84728183f, 0.64367667f, 0.74899979f, 0.60570691f, 0.16417363f, 0.68852426f, 0.85486889f, 0.22585792f, 0.86953176f, 0.07465519f, 0.93096301f, 0.38008822f, 0.38752587f, 0.44004038f, 0.13170612f, 0.94541045f, 0.89349973f, 0.69245307f, 0.94978877f, 0.98776658f, 0.79445884f, 0.30607409f, 0.58264961f, 0.37980538f, 0.41810784f, 0.48903038f, 0.51615888f, 0.57682794f, 0.82481897f, 0.78341080f, 0.48446465f, 
0.17447931f, 0.71125424f, 0.30263851f, 0.70675352f, 0.03215584f, 0.92381065f, 0.22343694f, 0.08851149f, 0.91402490f, 0.70074717f, 0.30912192f, 0.37723206f, 0.97579397f, 0.23554587f, 0.95939133f, 0.41565709f, 0.01741416f, 0.58362787f, 0.22106662f, 0.89065537f, 0.31900249f, 0.41280911f, 0.67947610f, 0.04545590f, 0.15352812f, 0.85412524f, 0.84933222f, 0.80000225f, 0.93147073f, 0.70094105f, 0.69269875f, 0.95282194f, 0.65913582f, 0.79186874f, 0.59855248f, 0.39707430f, 0.95126239f, 0.15618217f, 0.33446689f, 0.98123758f, 0.84770758f, 0.98081012f, 0.54427413f, 0.18728519f, 0.89792955f, 0.53360126f, 0.72812986f, 0.13307744f, 0.51217443f, 0.66708084f, 0.29416915f, 0.31298995f, 0.39155037f, 0.29288291f, 0.87063305f, 0.61759154f, 0.73723332f, 0.37167635f, 0.82122716f, 0.22937430f, 0.76570536f, 0.47911792f, 0.02826214f, 0.94277323f, 0.59945469f, 0.19042060f, 0.68173155f, 0.82771295f, 0.95649538f, 0.40833101f, 0.90838542f, 0.55245881f, 0.49011012f, 0.36773444f, 0.34513527f, 0.42050683f, 0.16113964f, 0.30969388f, 0.27174174f, 0.12117655f, 0.35270175f, 0.81967867f, 0.63723136f, 0.84309389f, 0.71822576f, 0.84883484f, 0.32306117f, 0.08176457f, 0.56175486f, 0.34892198f, 0.09306929f, 0.85437582f, 0.13925577f, 0.48629188f, 0.29923539f}); + auto exp = NDArrayFactory::create('c', {3, 8, 8, 16}, {5.98743296f, -2.83037376f, -0.87943113f, 1.41339970f, 1.32433391f, -1.20299149f, -0.02893090f, 2.05326009f, 1.19417048f, 5.58212376f, 3.28139353f, 1.19237995f, -1.09431255f, -2.55264497f, 3.11014652f, 6.81296825f, -2.09029293f, -4.32068443f, -0.52808392f, -1.97968531f, -0.18673831f, 0.84605980f, 4.55825520f, 2.71503139f, 0.15210046f, 0.85310984f, -3.82062817f, 2.76470995f, 3.69004202f, -1.45017099f, -2.59361267f, -1.35094655f, 7.24145126f, -5.25432396f, 0.19920218f, -4.30596399f, 1.35318923f, -3.88142037f, 3.67493343f, 2.25931478f, 2.87630725f, 1.66349852f, 6.21347952f, 0.94105923f, -1.61742055f, -2.35699606f, 0.12850338f, 1.79141688f, -2.09535933f, -6.35418081f, -0.06303531f, -4.38615131f, 
0.48237842f, 0.26528549f, 3.38231516f, 3.76315165f, -0.40254810f, -0.23716694f, -6.13381910f, -0.41950428f, -0.89680839f, -1.46491277f, -1.98541689f, -0.99357355f, 5.58237648f, -2.38937521f, -0.00872564f, -2.37138414f, 4.91117287f, -4.51916361f, 0.97943687f, 2.91052818f, -2.50362611f, 1.70252812f, 5.04137802f, 3.57108784f, -1.87532270f, -3.66677809f, -2.38861251f, 5.55765152f, -7.27571774f, -1.68887305f, -0.72266489f, -4.42809057f, -0.92118186f, 1.02381468f, 4.44284725f, 5.17150497f, -0.42438728f, 2.02693963f, -1.36484981f, -1.47912180f, 0.26649538f, -0.02091765f, -2.86906910f, -3.03046989f, 1.35122132f, -3.21707630f, 2.21112418f, 0.24121630f, 3.96940088f, -7.66105747f, 2.76352382f, -0.99061489f, -2.16720009f, -1.63170409f, 1.12701774f, -1.02415371f, -0.90435314f, -1.51372027f, -0.76884907f, 0.39066136f, -0.89562428f, -2.03204703f, 1.28074932f, -2.14551091f, -2.36843777f, 0.46580017f, 0.75451565f, -0.00336730f, -1.06597757f, 3.27195978f, -0.41307712f, -0.10376054f, -1.34102952f, -2.22901654f, 2.31929803f, 1.40851438f, -2.23774385f, 0.20417206f, -1.12153268f, -0.13188094f, -3.96649432f, 2.10269976f, 0.49845099f, 6.18937683f, -0.51783508f, -0.48048639f, -1.92970264f, 3.16670656f, 1.13355756f, -0.07890664f, 1.31536257f, -0.43924797f, -0.04562932f, -0.87974954f, 0.75411212f, -2.39745235f, -3.97132111f, 0.37202546f, -2.40399146f, -1.50796390f, -3.08302689f, 0.23075986f, -0.94316757f, 1.34948587f, 0.58591264f, 2.18529797f, 7.97652435f, 2.32798409f, -4.09404373f, 0.89634895f, 0.77697754f, -0.65091681f, -7.05506849f, 5.86194515f, 2.51394033f, 4.69959354f, 0.20835471f, 3.18049693f, -1.29682434f, 3.70832396f, -0.48123091f, -1.67904007f, -1.35418940f, 1.58435583f, -1.13851106f, -1.19225955f, 0.59713769f, -5.80462933f, -7.45143986f, -1.08658695f, 1.03244078f, -1.75307107f, -7.07100582f, 3.85825157f, 1.62127817f, 2.32572675f, 0.56171900f, -0.80591971f, 3.98835945f, 0.15742642f, -2.97832179f, 0.13821673f, -0.72556758f, -0.84936106f, -7.28444147f, 3.94134307f, 0.80779338f, 
7.47784615f, 8.23335075f, 4.80595016f, -4.89574575f, 4.03362942f, -6.67522192f, -4.55204487f, 2.12511182f, -2.70781207f, -1.57226098f, -3.08408356f, -0.30812448f, -5.32870674f, -5.13238287f, 0.49605465f, -0.55042171f, 0.46324944f, -3.83545256f, -0.12562510f, -0.20978995f, -0.13068712f, -1.92144060f, -1.68787408f, 5.45581436f, -0.79583496f, -2.38866687f, -3.90546346f, -0.47028148f, -0.14319679f, -3.37016582f, 2.00905991f, -1.21345615f, 1.81376505f, 7.73004007f, 0.74310112f, -4.64536428f, 3.78111577f, -9.05182457f, -0.10674095f, 1.53476238f, 0.63345337f, -0.40907967f, -1.44729769f, -1.87145400f, -2.46623540f, 1.07472968f, 0.77390999f, -3.93438888f, 4.49174690f, -0.96686655f, 1.92278123f, 0.30049133f, -0.02388665f, -1.99777114f, -3.23885751f, 5.87784004f, 2.13776040f, 3.56758308f, -3.37774134f, -3.67526293f, 1.63700044f, -1.69959962f, -0.99112594f, 6.03103638f, 1.67399430f, -1.28699589f, 7.16759014f, 12.63490295f, 3.62937450f, -4.75982571f, 2.17861104f, -2.03065681f, 4.30207729f, -0.46797156f, -2.96022511f, -6.02702332f, 3.09229851f, -1.39771092f, -0.03471333f, 3.22175527f, 5.63565636f, 1.78195477f, -0.63545251f, -3.99497652f, 1.46043062f, 4.60050488f, -2.96651959f, -2.03159475f, -1.52386189f, -0.15129802f, -3.90390921f, -0.63852370f, 0.79210538f, 2.35288715f, -5.55609035f, 5.36427498f, -0.60248077f, -0.26181316f, 5.04884720f, 8.53192806f, 5.05080223f, -6.56371737f, 1.52260923f, -7.13623667f, 6.49414349f, 2.33445597f, -4.11490965f, -6.44347477f, -0.47079402f, -0.63467920f, 2.60399365f, 1.05958164f, 3.66901422f, -1.05657935f, 1.88611507f, -6.37475634f, 2.01480770f, 3.36020517f, -5.11001921f, -0.46132171f, 2.16525555f, 4.21938848f, -2.08346295f, 2.86168146f, 1.26987600f, 6.76066971f, -7.84916353f, 4.11700916f, 0.47985530f, -4.60113716f, 7.42062473f, 6.37472820f, 4.37820530f, -7.12197018f, 0.01357239f, -7.90392113f, 8.32131577f, -0.87593079f, -0.16994858f, -5.86345863f, -0.20697471f, -1.37845206f, 1.63819647f, 1.59720242f, -0.74357712f, -1.88725603f, -1.98357940f, 
-8.57950306f, -4.10104513f, 3.57231879f, -2.89855957f, -0.11263305f, 2.78033924f, 1.53078973f, -2.93089223f, 0.73189604f, 3.20563078f, 3.92601013f, -5.21916151f, 0.89163935f, -0.42978728f, -6.70888853f, 4.56477976f, 1.20105875f, 3.83393812f, -6.27205181f, 4.05993128f, -7.35513067f, 1.60660768f, -1.21052051f, 1.58191252f, -1.37899971f, -1.20117283f, 2.93301678f, 1.06302834f, 1.38993621f, -1.66884089f, -3.34452581f, 1.04498529f, -4.10412455f, -4.03310585f, 1.61513603f, -1.09388447f, 2.11451387f, -0.94192362f, -0.23287666f, 5.88265705f, -0.83010495f, -2.15317154f, -0.60276151f, -1.49265075f, 3.93397975f, 5.45194483f, 1.45161700f, -2.57401872f, -5.59288931f, 4.29170895f, 1.87151814f, 0.08362055f, -0.28767288f, 1.17675185f, 0.85266006f, 1.30549634f, -5.60830832f, 0.19398519f, -0.83982587f, 1.75940764f, -5.46077394f, 1.64495635f, 0.17102760f, -0.54459631f, -2.21975255f, -0.37443402f, -2.08474159f, 1.85959935f, 11.19680309f, -0.18611598f, -2.59765387f, 3.06330776f, -1.52183700f, -4.88415241f, -0.75097847f, 2.58201051f, 7.40885210f, 3.58994508f, 1.62457407f, 3.12514591f, -4.36833286f, 1.39830995f, 3.61003447f, -0.63837433f, -3.62661815f, 3.78898096f, 2.92802262f, 5.87374496f, -4.38554621f, -2.53411579f, -2.87311554f, -1.31391978f, -4.26736879f, 3.45099425f, 1.58769250f, 1.73341393f, -1.08842182f, 2.27120280f, -1.78938174f, -2.29940319f, 7.07046986f, 0.51426595f, -6.22928905f, 5.28968811f, 2.31827855f, -4.20915890f, -1.27249205f, 5.92120600f, 3.19458675f, 7.09252501f, 3.96577907f, 6.41484213f, -4.66009521f, 10.00181389f, 0.51108456f, -4.62243366f, -5.18351841f, 2.12961674f, 5.10694027f, 7.29412317f, 0.15912467f, -3.38902974f, -4.01918602f, -2.17383957f, 0.13118666f, 0.27872476f, -0.92317247f, 3.51440644f, 1.84171486f, 1.03378081f, 1.30569839f, -2.09583759f, 9.03952980f, -0.55187917f, -2.04549074f, 1.08294606f, -2.65263700f, -2.93977118f, 1.88909876f, 0.96043622f, 1.76579499f, 3.14314699f, 5.86394691f, 7.36944389f, -7.04524136f, 6.68673229f, -5.52591467f, -2.19745898f, 
-4.32036924f, 0.52971321f, 2.26268244f, 6.91575766f, -0.94590527f, -3.98923349f, -0.12266219f, 0.24294075f, -1.07783222f, 1.87989080f, -3.57109427f, 1.61553633f, 0.42486978f, 0.75852054f, -6.19481468f, -3.80570698f, 2.39946675f, -1.93851781f, -5.42234039f, -6.34092760f, -2.52374983f, -1.85044456f, 3.92693520f, 0.40042299f, 4.69742584f, 5.40483189f, -1.02398944f, 8.89605045f, 0.64680403f, 0.89943957f, 0.76993859f, -1.88244629f, 1.90714884f, 3.10836840f, -0.17064989f, 0.84892416f, -6.94988108f, 1.92141032f, -1.36458397f, 6.39284658f, 0.45201308f, 2.58823442f, 6.33375788f, -4.76916075f, -8.45738983f, -0.48962492f, 2.40652561f, 4.56602001f, -3.34420681f, 1.86862195f, -7.01420689f, -6.94657421f, -2.47419310f, -4.61693668f, -0.18822384f, -0.36949772f, 2.01374269f, 4.11018658f, -5.11564064f, 8.04294395f, 2.88567662f, -2.87645102f, -1.23238611f, -5.91409397f, -0.62205851f, 1.38689423f, -0.01120412f, 5.25955677f, -1.98474956f, -3.72012186f, 3.00445986f, 4.99141550f, 2.97457719f, 2.70827627f, 6.04544449f, -0.20756161f, -10.87035751f, 0.80454814f, 0.33568168f, -2.48132324f, -2.84452009f, 2.63126230f, -3.99351716f, -7.39294338f, 3.62798953f, -8.65815926f, 2.65992808f, -6.98126554f, 3.09881067f, 0.67735767f, -1.15946686f, 5.63180256f, -0.17694545f, -8.59651184f, 3.75297594f, -2.35913754f, -0.20330384f, 5.49958467f, 1.00861740f, 1.42849684f, 0.00062013f, -0.11073381f, 2.15207863f, 4.07368469f, 1.14344299f, -1.27953362f, 6.64699316f, -0.73672432f, -8.55606937f, -0.19439441f, -4.14319754f, -4.69964647f, -5.86446047f, 2.87106085f, -3.42714882f, -5.00668287f, 6.22464132f, -7.72335291f, 4.05667686f, -5.72637177f, 6.35073948f, -1.29593158f, 0.00813985f, 3.63368607f, -1.05764008f, -7.88486052f, 3.73919106f, 1.41835213f, -1.04935634f, 0.65119827f, 0.03547254f, 1.88996327f, 1.58701086f, -0.56215239f, -0.80187100f, 4.55604362f, -0.67249978f, 1.41084409f, 7.86281586f, -2.38301182f, -8.50535774f, -3.82098866f, -2.40856767f, -5.33439016f, -3.34747362f, 2.69389009f, -1.64118791f, 4.52447939f, 
0.04468334f, -1.48768258f, -0.69848812f, -0.71123981f, 3.66259432f, 6.10314512f, 1.37305343f, -0.62758982f, -2.99383426f, 4.20510864f, 1.48497128f, -0.08954811f, 2.43872309f, -0.59880185f, 0.37431365f, 2.45458341f, -3.28401661f, -1.94629693f, -1.93975246f, -0.26385683f, -0.45814323f, -0.18108580f, -3.74811840f, -0.29739976f, -2.24116230f, -0.28150487f, -2.24421668f, 3.46930790f, 8.35415077f, 0.05562943f, -2.81079793f, 1.10388446f, -2.82245207f, -2.98102283f, -1.08132946f, 1.19089699f, 8.00183105f, 6.35385323f, 3.72591257f, 4.59467506f, -5.74890900f, 4.42238331f, -3.36533451f, 0.18350232f, 3.05606651f, 1.18788099f, 2.87450886f, 0.27472210f, -2.80111074f, -0.66314960f, -1.96376896f, 0.75167024f, -4.72056293f, 1.10629988f, -5.00775242f, 1.48246133f, -3.91681528f, -1.86573625f, -6.17714882f, -0.67820001f, 5.69730282f, 1.04399037f, -4.93794823f, 3.09619617f, 2.18692017f, -5.54232264f, -3.10046840f, -0.68972743f, 2.81824327f, 3.04334164f, 6.13203907f, 4.14081764f, 1.02573645f, 5.71970081f, -6.01574707f, -2.07346702f, 0.99554527f, 1.69641590f, 0.66776669f, -0.80132431f, -2.03513098f, -3.42513680f, -0.06704485f, -1.87195873f, -5.42428589f, -0.20748445f, -1.52408111f, 0.97084987f, -0.48799962f, -0.45379883f, -0.26652339f, -1.20720732f, 3.94169855f, -3.18480229f, -1.87440264f, -1.18028760f, 0.52011997f, -2.13437462f, -4.52583313f, 1.69722807f, -0.89371562f, 3.37972403f, 6.38838720f, 6.98663378f, -4.05421400f, 6.89512825f, -5.09085655f, -2.16257906f, -3.33272719f, -3.01246452f, 0.37613097f, 1.80455804f, -0.36456174f, -5.32273912f, -1.29978943f, -0.53685790f, -2.12896323f, 2.55506587f, -2.57999182f, 3.40891910f, 1.36033249f, 0.83864629f, -2.88629293f, -7.36048365f, 5.61314154f, 1.32668555f, -2.58041072f, -3.71943092f, 1.60647738f, -2.74816346f, 2.47269106f, 0.85507953f, 8.39183426f, 3.42624784f, -0.01519036f, 5.68412066f, 2.51771593f, 1.03045523f, -2.08733034f, -2.44337177f, 0.81668580f, 1.30275154f, 2.99679208f, -2.91957355f, -1.71337795f, 3.34979844f, 1.51825011f, 
5.20375061f, 2.27888370f, 1.38787699f, 4.23474550f, -4.05878592f, -4.85074377f, -0.22794735f, 4.64402294f, 1.24391258f, -2.04935098f, 1.26285601f, -7.51862240f, 0.62138438f, -1.95792389f, -0.96587181f, 0.85141110f, 0.79354531f, 7.93766356f, 6.07677746f, 2.05947518f, 6.55480623f, 1.44032848f, -0.70615625f, -0.07896036f, -5.08359432f, -0.01047915f, -1.89632201f, 2.57555676f, 3.83779287f, 0.42850614f, 1.80754125f, -0.06942326f, 6.35997963f, 6.06101418f, -0.97032297f, 5.71477222f, -6.06671238f, -3.46607208f, -4.98306370f, 2.84659123f, -2.11025190f, -0.04609144f, 5.26831341f, -9.56940651f, -3.67193556f, -1.71143103f, -1.35221267f, -4.26226807f, -6.89146233f, 8.21761799f, 5.69823503f, 2.28137946f, 1.88911343f, -1.44562483f, -1.60295713f, -0.52568185f, -3.31892347f, -2.81997776f, 0.35287106f, 2.98202395f, -1.39432132f, -2.70001364f, -4.14169264f, 3.50194883f, 4.12610435f, 5.52755260f, 2.65859175f, 3.61353087f, -0.83027136f, -5.10652542f, -4.48625374f, 2.06585884f, -2.76383352f, -0.64300913f, 8.19686604f, 0.96106279f, 2.45952058f, 2.47275925f, -1.03288829f, -0.64897656f, -3.77937531f, 4.27940083f, 2.58320260f, -0.57665241f, 1.87247813f, -3.81604433f, -0.24543774f, -1.62118483f, -0.73075479f, -0.48533297f, 2.05016756f, 0.45561486f, 0.03316188f, 0.77791005f, -1.56283605f, 2.36616826f, 5.58082104f, -1.30925488f, -1.06329608f, 2.17189479f, -3.43008828f, -4.71520567f, -2.56184673f, 0.17508316f, -3.25817418f, -0.41749167f, 0.18119079f, -0.73181152f, 3.99792433f, -3.08002281f, -0.99143314f, -1.83520067f, 1.18565679f, 2.98040128f, 5.67814350f, 2.35128760f, 1.41600966f, 4.02718067f, -0.08193968f, 0.64636409f, 1.35931289f, 2.37125754f, 1.75978124f, 3.90977740f, 1.50662971f, -2.84089065f, 1.29824126f, -3.38730979f, -1.61005294f, 0.58292413f, -0.03019404f, -1.57986510f, -0.56102908f, -3.03128719f, 0.51644313f, -2.01147819f, 0.98400700f, 3.00028515f, 0.74579155f, -3.37098312f, 0.93339360f, -1.29018497f, -2.14695001f, 1.30411184f, 0.71501279f, 7.47793055f, 4.06516457f, 3.50772929f, 
3.52762985f, 0.55643129f, 0.32272506f, -4.30955982f, 2.49414706f, 2.07820845f, -0.34377906f, 4.39805031f, 2.77561307f, -3.91292810f, 2.43981409f, 0.18861845f, -2.76658440f, -4.97148752f, 3.25273705f, -0.08929539f, 0.19818619f, -5.83767605f, -0.97381884f, -5.68745661f, -5.42433214f, 3.98769903f, -0.40394354f, -1.83387578f, -0.80109525f, 1.47454357f, -3.14899540f, 0.80130816f, -2.26348829f, 4.06121159f, 6.13077354f, 5.31226397f, 2.94966197f, -3.65217376f, -1.08136678f, -7.14119816f, -0.85269439f, -0.70365787f, -0.81598872f, 3.62807679f, 3.08123684f, -7.82739496f, 4.07951784f, -0.14204243f, -0.66969109f, -5.07225513f, 2.88492823f, 0.47202343f, 0.72683257f, -6.84280777f, 0.41807127f, -5.09785986f, -3.74514675f, 2.03936672f, -1.06096244f, -1.52409148f, -0.97046643f, 2.27491093f, -1.55597985f, -1.29215479f, -0.79737484f, -0.01979581f, 7.65407991f, 5.54527044f, 4.04147148f, -2.64274883f, -1.89246953f, -3.89547634f, -1.06029689f, -2.85982800f, -1.41247237f, 1.55836034f, 3.38194537f, -2.97655582f, 0.87510300f, 1.26282072f, -1.77029657f, -3.57144690f, -4.19456863f, 0.53179169f, -1.42221975f, -3.09144497f, -0.84294832f, -5.02758694f, -2.68011904f, 0.89156240f, -0.34783912f, 4.64484835f, -2.34453487f, -1.28573155f, 0.09990287f, 0.01828218f, -1.79960847f, -1.06579173f, 1.08763921f, 0.43687880f, 3.24747229f, 3.83097172f, 1.07253766f, -1.33810723f, 0.76530832f, 1.58660865f, 5.60743904f, -3.54124737f, -0.89264417f, -3.83942485f, -1.03707337f, -1.61659896f, 1.65349591f, 1.72698796f, 4.96013832f, 0.78927267f, -0.35563886f, -3.48121166f, 3.79677629f, 2.59023166f, 2.74940348f, -2.17589283f, -5.91757107f, 2.43766379f, -4.15906048f, -1.74731481f, -2.49113035f, -0.57349741f, -4.04455185f, -1.46939647f, 2.21418452f, 0.09153593f, 2.23016739f, 7.91880608f, 4.04464149f, 0.07706618f, -2.41892862f, -2.19280314f, 7.61760712f, -5.89153862f, 0.33551922f, -1.70855618f, -0.30561331f, -0.14341974f, -2.48878574f, 1.31269515f, 3.45388412f, -0.02453184f, -0.12132037f, -4.27916241f, 1.25179088f, 
4.09455204f, -1.83801770f, -1.86743176f, -4.02864933f, 3.44515228f, -4.39244986f, -0.56988084f, -1.69426417f, 2.18254852f, -4.78135824f, 1.73193693f, -2.27968478f, -1.49523509f, 2.51696730f, 4.03677559f, -2.03679037f, 1.32167840f, -2.22570705f, -2.74843621f, 6.29655170f, -3.67230225f, -1.86765468f, -0.14842367f, -1.21552539f, -0.92038238f, -0.51692355f, 1.08433771f, -0.01929832f, 0.15660909f, 2.31432915f, -3.86507082f, -0.69797570f, 0.13505173f, -1.50951028f, -0.69980979f, -1.51297045f, 3.63725281f, 0.13388813f, 2.73131752f, -0.96528149f, 4.92000961f, -5.92699385f, 1.69444644f, -1.17121375f, -2.33710480f, 1.35302818f, 1.39608085f, 1.68293881f, 0.94960749f, 1.89011908f, -4.08865070f, 0.13722643f, -1.62849212f, -0.19044125f, 1.37906075f, -3.92504406f, -1.45033538f, -0.42085981f, 3.38237071f, -3.06508875f, -1.39420545f, 1.13067436f, 0.92206454f, 0.49917889f, -2.74508023f, -2.19221997f, 1.77914095f, 0.10854459f, -2.62178278f, 2.35042715f, -0.15322030f, -0.67014873f, -1.75627899f, 2.64074945f, 2.76339936f, 2.67275214f, -0.62736398f, 0.58251178f, -4.64895678f, 5.50419283f, 2.53566456f, -2.44196153f, -0.07845879f, -2.80389643f, -0.64810950f, -0.05813205f, 1.67155504f, -2.69673729f, -1.72486305f, -0.53888649f, 1.86805439f, -1.37128329f, -5.37923479f, -2.08133769f, 0.58187997f, -1.39498150f, 0.21874082f, 4.33726025f, 6.29673958f, 0.72312093f, -3.32683516f, 1.73482585f, -0.00766110f, -2.63785434f, -0.13511759f, 4.07195950f, 0.94139838f, 3.15717316f, 1.53720927f, 1.87664819f, -2.33655119f, 6.18176556f, -2.73912525f, -2.45279956f, 2.20392370f, -0.56854641f, 0.98915887f, -2.64472580f, 2.40633702f, -4.93327999f, -1.28942823f, 0.98247659f, 1.31774998f, 0.07669818f, -5.91169453f, -0.43135011f, 1.27404964f, -0.59787154f, -0.22716975f, 0.74409103f, 10.27316475f, -2.29192710f, -2.19403267f, 3.78925133f, 3.19553399f, -4.42490482f, -0.80781460f, 2.16568565f, -2.54165983f, 2.54885101f, 4.18779039f, 1.73079813f, -1.48891807f, 11.60153770f, -0.98686743f, -2.88813901f, 2.32898521f, 
-0.36101711f, 2.34522438f, 0.29057693f, 1.39800644f, -4.31848240f, -3.21217132f, 0.11740226f, -1.21613467f, 0.57248503f, -4.44853830f, 1.54665899f, 3.14459944f, 1.76809108f, 0.26693153f, 0.86913753f, 9.47121620f, -2.07677889f, 2.08578467f, 1.30181742f, 1.58683562f, -3.52757788f, -1.32763624f, 0.79821301f, -2.19358301f, 1.17707348f, 6.01983643f, 4.11209440f, -2.04209709f, 7.00413418f, -1.84904683f, -1.32542288f, -0.01298118f, 0.70377320f, 0.27815005f, 2.07879829f, -0.71606725f, -4.94399881f, -2.11898828f, -0.39051518f, -2.21034360f, 3.05337906f, -1.56889665f, 1.97065282f, 2.61320901f, -0.34063196f, -0.57001418f, -2.13183641f, 3.48879004f, -0.12067288f, 0.48568326f, -1.81424558f, 2.28868723f, 1.44802380f, 1.25918829f, -1.76415455f, 5.35742331f, 3.50682044f, 4.71371317f, 5.89110756f, 8.51241302f, 4.07391453f, -0.05887252f, -0.18202400f, 2.27119660f, 6.78274727f, -2.87470293f, -5.14336634f, 0.76443815f, 2.04625130f, -0.43199503f, -1.01353514f, 2.42951298f, 2.35641170f, 0.32345510f, -4.04195738f, -4.77967072f, 0.26564783f, 6.11455107f, -2.53868008f, -3.11839914f, -1.04203856f, 5.17195654f, -4.15338612f, -3.84149241f, 0.48130888f, 3.09706950f, -4.18423653f, 5.26233864f, 3.55831861f, 3.75122595f, 8.14969349f, 6.80038738f, 4.68907356f, -1.40135396f, -3.19287133f, -3.15895939f, 8.77363205f, -4.48793411f, -3.80537176f, -2.40145254f, -2.74341679f, -2.02862644f, 5.33402443f, 9.25365734f, 2.50246119f, 0.32847846f, -1.50564361f, -4.26163197f, -1.40994716f, 2.50708485f, 0.44500345f, -0.62516934f, 4.09846306f, 5.29355669f, -4.02224922f, 0.73442125f, 0.46648952f, 0.67028689f, -6.30715466f, 6.56297970f, 3.80854273f, -5.19078207f, 4.98839283f, 7.59161472f, 0.46010983f, -2.10227895f, 0.29324162f, -2.67019558f, 4.57838106f, -3.02338457f, -3.08647728f, -2.00112700f, -3.81710315f, -0.08346784f, 1.69288683f, 5.68807268f, 3.29351830f, 0.54618967f, 1.83540761f, -5.38810253f, 0.51326782f, 4.40081882f, -4.03805828f, 0.49482727f, -1.36024392f, 2.91845679f, -2.00959015f, 2.47489738f, 
-1.43354976f, 1.92024410f, -6.55897284f, 1.79488957f, -0.89570928f, -6.13094234f, -0.45504010f, 2.35239482f, 1.29039919f, -4.78849840f, -1.52545333f, -6.50420475f, 2.99257326f, -0.55620033f, 0.26807702f, -2.52090979f, -4.59419632f, 0.57965040f, 2.19423151f, 2.04760551f, -0.57048106f, -2.20812702f, -0.04777686f, 1.38053393f, -2.71448946f, -1.06219673f, -3.62008905f, 1.85719645f, 1.28355026f, -2.76315832f, 1.65295160f, -4.01645803f, -3.10454416f, -0.65713316f, 1.22384977f, -0.70416176f, 4.45064926f, 1.31602776f, 2.06907344f, 2.48872757f, 4.25775290f, 3.50504255f, -0.68262041f, 1.29799378f, -1.01969171f, 2.98593879f, 0.12607655f, 0.37219539f, -0.84196299f, -3.80019331f, -1.82315290f, -0.38489276f, -1.45200360f, -4.00882292f, 0.61042011f, -0.16738498f, 1.33787775f, -2.26938057f, 1.03656030f, 8.89089870f, -1.60370600f, -5.38691807f, 5.72182989f, 2.72854710f, -6.18535757f, -3.13408709f, 2.79175353f, 5.18425512f, 9.46434212f, 2.40110517f, 1.11330092f, -3.57366538f, 4.80967665f, 0.40691876f, -3.65484858f, 0.92398167f, 2.53852940f, 3.17747331f, 2.14199781f, -1.69107199f, -1.91864693f, -3.18452644f, -2.42408276f, -2.14332366f, -1.35526609f, -4.50732136f, 0.58234072f, -1.81547785f, 0.57311213f, 1.10584176f, -0.97226644f, 11.73174381f, -2.00559855f, -1.81175601f, 2.33131361f, 0.49264961f, -0.42245382f, -1.37528467f, 1.55768061f, 0.21152198f, 13.08896351f, 10.33674145f, 5.77929306f, -6.19886398f, 5.67007637f, -6.61288071f, -2.58029866f, -4.05192375f, 1.77221894f, 0.29821560f, 5.23508501f, -5.09560966f, -0.97536200f, -5.17957878f, 1.02876794f, -4.52072096f, 2.22126532f, -4.81708670f, 0.44538212f, -2.30738068f, 3.15900373f, -4.99227905f, 0.82632786f, 9.65415478f, -0.63819492f, -3.25479436f, -0.13276935f, 0.21337092f, -2.22116399f, -3.04922724f, 0.65568435f, -0.10706246f, 4.58047390f, 7.80782652f, 5.49080181f, -3.97114491f, 6.43327618f, -6.54772758f, -2.10962629f, -0.79831678f, -0.08316499f, 2.48658133f, 4.14070511f, -0.59806836f, -4.58636141f, -0.31166920f, 0.31757897f, 
-3.92562199f, 0.65357721f, 0.55871534f, 1.71843934f, 1.62395024f, 0.00695819f, -4.56716251f, -3.76420808f, 4.24979544f, -0.86128616f, 0.23126510f, -6.32968998f, 1.83346081f, 3.81335950f, 2.98407745f, -1.80454743f, 6.61764765f, -1.39372075f, -0.86780751f, 7.24317265f, 2.24205112f, 1.05702817f, 0.55431479f, -1.54557061f, 3.36389136f, 4.70898724f, 1.11327887f, -3.78462076f, -3.63381767f, 2.86510396f, 0.74203897f, 0.81488025f, 3.54250598f, 3.24824381f, 3.19000244f, -0.58995843f, -7.05670738f, 3.18306041f, 3.95191574f, 0.81820154f, -1.91068232f, -2.05426741f, -1.05589008f, -3.18377590f, -1.86278260f, -8.80374908f, 0.93416154f, -4.60517359f, 8.38999462f, 5.26356745f, -8.89992714f, 8.95298958f, 4.22590351f, 1.00351548f, -6.90151119f, -8.07641125f, -4.82450199f, 8.02293015f, 4.11661243f, 0.95457208f, -7.07843113f, -4.30524826f, 5.02697992f, 5.21011686f, 0.80132771f, 3.23420191f, 3.82452774f, -2.13171721f, -7.88879967f, 1.31062031f, 1.90848613f, -3.51572514f, -3.75684500f, 3.62577081f, -5.76075602f, -2.79389215f, 0.32598805f, -4.28981733f, 4.21048594f, -3.84532523f, 3.19815183f, -0.40756655f, -2.19974327f, 6.25655174f, 3.42396951f, -1.88986623f, -1.92803884f, -2.97344875f, -0.09756154f, 5.24342251f, -0.72513700f, 1.06113195f, -1.30720282f, 4.69107103f, 0.58984971f, 2.33985567f, 1.46385121f, 3.16576266f, 6.77769995f, -5.92685127f, -12.61141014f, -2.83663774f, 4.90253258f, -6.32688522f, -3.00096869f, 2.38634992f, -7.21459866f, -5.89208746f, 2.84085894f, -1.21792030f, 6.70161343f, -4.00450230f, 5.29881001f, -1.45574808f, 0.77542424f, 1.38336325f, -0.21572059f, -3.38088870f, 2.33249640f, 0.68824625f, -3.68440270f, 0.33481622f, -0.39239681f, 0.14560902f, 1.61039007f, -3.11967754f, 2.49372435f, 2.68783092f, -1.17559779f, 0.95257235f, 4.35451412f, -0.56818569f, -7.32110357f, -7.58534050f, -2.10573673f, -3.34446383f, -0.32183546f, -0.78525496f, -1.76974547f, 5.19060802f, -2.11319876f, -3.41755080f, -0.36864156f, 1.32680905f, 0.45004874f, 6.17223930f, -1.60707474f, 0.46096295f, 
-3.88852644f, 1.84729624f, -0.03412050f, 0.99224162f, -2.05553341f, 3.47793245f, -0.06305170f, 0.51314175f, -2.91650558f, -1.78121483f, -2.85465693f, 0.24649808f, -2.70376635f, 0.42334458f, -1.13862336f, -0.98409218f, -0.96593523f, 2.22128963f, 0.53402066f, 3.33979344f, 8.57430458f, 2.34217858f, -2.40062976f, 5.81624222f, 1.13290989f, -5.06850052f, -4.72865725f, 1.82859278f, 6.78569555f, 8.56885242f, 2.76462936f, 0.33891773f, -2.81092787f, 0.79498398f, -2.27208567f, 1.55182552f, 2.17166376f, 6.12517643f, 3.56859684f, 0.27685475f, -1.38408327f, -1.03533340f, -3.46618199f, 0.79240030f, -3.89390516f, -0.55852515f, -1.16367757f, -0.07008934f, -2.20105195f, 3.81210446f, -0.66834474f, 0.43603873f, 10.92334938f, 2.48571420f, -6.34997845f, 4.23135757f, 0.45045292f, -4.13489866f, -3.92324209f, 1.88537407f, 2.57159734f, 9.90973091f, 4.37453461f, 7.34546280f, -2.51120615f, 11.12575245f, -3.23452854f, -2.49947500f, 1.39819741f, -3.78950691f, 2.40617585f, 5.10036278f, -3.55743456f, -6.42888737f, -2.51929998f, -1.90880990f, -1.81618094f, 1.60946512f, -4.09737110f, 1.96408439f, -1.90115595f, 2.44444203f, -2.31254292f, -4.01332951f, 8.65541840f, -0.58626485f, -4.02226830f, 0.43893200f, -3.78272748f, -5.46277428f, 0.01306701f, 0.61185312f, 0.24469066f, 1.30214953f, 5.87789631f, 8.75197792f, -5.31634712f, 3.43556309f, -5.90755081f, 0.54375106f, -2.48162293f, -3.51843548f, 2.55853295f, 5.06387186f, -2.09662485f, -3.00377345f, -3.21781397f, -0.14537808f, -4.65453672f, 1.92747557f, 0.41553855f, 4.09379959f, 0.83387995f, 1.50868511f, -6.54959488f, -8.38881016f, 5.50689125f, -2.88616610f, -1.21597648f, -0.23817590f, 1.50816703f, -2.26873541f, 2.29862142f, -1.61143053f, 5.97371244f, 4.71440220f, -0.20635787f, 8.85926723f, 0.56064367f, -1.04103339f, -4.47060108f, -2.63824081f, 3.06782055f, -2.07702565f, 3.38269401f, -1.59988797f, -3.80122590f, 2.35341501f, 2.69095278f, 3.87612104f, 1.89984226f, 0.95496917f, 3.14841127f, -5.84543085f, -7.24945450f, -2.65708590f, 2.87417006f, 0.97556210f, 
-3.75203967f, 1.55287778f, -7.43401051f, -1.29005826f, -3.40252638f, -4.01049423f, 2.82721639f, -1.21479535f, 8.54563904f, 7.39749908f, -0.61361837f, 7.60177565f, 1.65812778f, -0.83008504f, -3.60961151f, -7.69062138f, -1.26275063f, -4.17071676f, 5.28448200f, 4.04685593f, -1.18231702f, 1.15276611f, 1.58620787f, 6.75060844f, 3.29332161f, -0.67640316f, 5.78984785f, -3.14913464f, -6.41867924f, -2.58316016f, -2.04366302f, 2.01089478f, -3.81723452f, 3.63843751f, -5.13238430f, -3.79432917f, 4.86581373f, -1.06922054f, 3.95978498f, -0.78166616f, 8.35650539f, 5.35834265f, 0.35594034f, 9.41657066f, -0.84108615f, -6.54425859f, -3.44328952f, -6.55536795f, -0.08963367f, -1.53906262f, 0.17658240f, -0.13108420f, -0.44371247f, -0.78411150f, 2.64754868f, 9.66306782f, 1.70506203f, -0.31588936f, 4.31715870f, -6.16665173f, -10.43371868f, -3.72962189f, 4.35245228f, -1.75867891f, -4.20046234f, 8.62637043f, 1.45946813f, -3.30153608f, 0.85179043f, -2.66643381f, 3.01863337f, -2.52916121f, 8.35405540f, -0.37298933f, -0.89473486f, 6.88681793f, -4.46370125f, -7.50776386f, 3.80255938f, -3.55003357f, 1.43528831f, -2.20383263f, 2.34999895f, 2.03803205f, 1.94830751f, -1.85976326f, 0.97718471f, 5.53710842f, -0.80560827f, 0.23925614f, 5.98795223f, -2.03578377f, -7.77835321f, -2.79955530f, -1.88185954f, -2.49112058f, -0.76095992f, 2.71161270f, -0.55918610f, 0.83789903f, -1.42063200f, -0.61528748f, -4.18273115f, 1.76384258f, 4.21265936f, 5.50964785f, -0.93324339f, 3.83215356f, 1.52210593f, -0.91594946f, 1.31148386f, 3.20160103f, 1.24493563f, -0.72693497f, 1.84716725f, 3.09897518f, -1.34605026f, -1.17511916f, -1.05526352f, -1.08590937f, -1.41319299f, -3.75052118f, -2.67095542f, -0.76179552f, -3.32081509f, -1.04692316f, -1.30194843f, -1.98795474f, 5.01223469f, 0.21895903f, -1.85535169f, 3.12362719f, 0.16198632f, -3.86784005f, -2.03062248f, -0.15415624f, 8.22020721f, 4.83055592f, 4.50315666f, 4.19443417f, 0.42727345f, -4.67786789f, -5.18739986f, 2.53988838f, 3.19683266f, 1.80313504f, 1.94664574f, 
0.59795094f, -4.21626759f, 0.50492239f, -0.41232634f, -0.99224532f, -3.94929314f, 1.74060190f, -0.92474866f, -1.00664830f, -6.17397356f, -1.33146775f, -3.78111315f, -4.91876888f, 2.50303864f, -0.34890354f, -1.25013232f, 0.38168997f, -1.84135628f, -4.46107960f, -4.05920792f, -2.61709857f, 0.71046209f, 9.80566883f, 6.34086990f, 2.73394704f, -2.03342366f, -2.21424174f, -5.56514263f, -4.74755144f, -2.20672894f, 0.09010231f, 1.70423889f, 3.19200158f, -6.99027634f, 1.14216340f, 0.05824995f, -0.76996505f, -6.51575899f, -0.41109252f, 0.78229940f, 1.36170781f, -5.65170193f, 1.12221193f, -4.60430050f, -4.40174437f, 4.01805925f, 0.10774946f, -2.77991009f, -0.18023163f, 0.02151692f, -1.77023101f, -1.86639869f, -0.69443607f, 4.92290831f, 6.83520412f, 4.27372265f, 6.54272366f, -7.59249687f, -1.40776849f, -3.52368808f, 1.01398587f, -3.58802676f, -0.35658866f, 1.14716864f, 3.75847244f, -2.30159235f, -0.72130895f, -0.24564353f, -1.77531350f, -3.08677864f, -0.73486501f, -1.20357263f, 0.60789430f, -3.46990204f, -0.20668676f, -5.46096087f, -5.22016764f, 0.98259866f, 1.81012678f, 3.92534304f, -2.94997001f, 1.65154219f, 2.27040243f, 0.99095678f, 0.09144652f, -0.99103236f, -1.11210847f, 0.78181303f, 2.38706732f, 2.96695375f, -0.17279971f, 0.31143007f, 1.35465562f, 2.03586054f, 6.19515753f, -3.14652419f, -2.89027119f, -3.26665854f, -1.93043876f, -0.46601450f, 1.07655203f, 1.74946189f, 4.02148342f, 0.69275337f, 0.50094581f, -4.07613230f, 2.98369169f, 4.24537849f, 0.49480581f, -2.02408123f, -2.02068973f, 6.54505825f, -5.19377470f, -0.12596917f, -0.70204186f, -0.98308045f, -3.19708824f, 1.63609934f, 1.35475993f, 0.16313422f, 4.13918924f, 7.69187021f, 3.72601676f, -1.97790039f, -1.16739464f, -3.31835508f, 8.14553452f, -1.78718984f, 1.21505618f, -3.84255409f, -3.21992350f, 0.07376552f, -0.81223297f, 3.57002878f, 1.48521733f, -0.45995998f, 0.30551746f, -3.33944130f, 1.39538884f, 1.84758544f, -0.21494150f, -2.27316713f, -4.37771225f, 6.48841667f, -5.00251961f, -0.45162797f, -5.01056004f, 
0.70199943f, -4.60057783f, -2.22394514f, 0.07777429f, -1.49820781f, 3.47308421f, 6.13231564f, 1.18605387f, -4.78924608f, -3.49548388f, -2.73382568f, 6.24617863f, -2.74291611f, -1.03833354f, -2.20752788f, -2.33219409f, 1.48633552f, 1.65796840f, 4.95045471f, 2.58479190f, -0.90922785f, 0.71312457f, -4.44465590f, 1.37020862f, 2.37683725f, 0.18805164f, -3.28422308f, -1.64939332f, 3.64181972f, -3.75277281f, 3.67203593f, -0.11204052f, 2.24140930f, -3.90657187f, 2.56883717f, -1.44016707f, -2.83842611f, -0.29104578f, 2.17757058f, -0.71431804f, 1.36911654f, 0.85083604f, -1.60110259f, -1.97247636f, -1.61163378f, -0.81236130f, -0.38993555f, -3.03631902f, -0.38213277f, 0.06394482f, 3.19348621f, 0.36771113f, 1.36763072f, 2.49159527f, -0.39599860f, -2.69996762f, -0.97561121f, -2.97563028f, -0.49662948f, -0.17564940f, -2.79042959f, 0.72395414f, 2.07260203f, -0.99439794f, -2.20248008f, -0.07389921f, 0.65536159f, 4.73054695f, -0.63917702f, 0.58788192f, -3.60156059f, 6.59609890f, 3.88419437f, -3.38469863f, -3.56237841f, -2.03295064f, 0.07279694f, 3.71804547f, 0.79928309f, -2.13411403f, -1.13909864f, -0.34193408f, -1.00338125f, -1.44231665f, -5.39835978f, -0.45086145f, 1.16064668f, 2.58335257f, 2.10072684f, 4.64244223f, 7.10090065f, 1.01974952f, -4.44687223f, 2.99792576f, 1.10303724f, -1.22736573f, -3.91514421f, 3.07458854f, 2.18765211f, 3.34481716f, 2.46166849f, 2.99648619f, -0.94046807f, 5.55028200f, 0.92199719f, -0.83934361f, -0.72042274f, 0.84869325f, 1.46914721f, 0.85937387f, 4.77306223f, -4.06436539f, -2.59847593f, 2.44828081f, 0.50484699f, -2.71092367f, -6.39010477f, 0.91778028f, 3.25469685f, 1.30310678f, 1.35258150f, 3.56171441f, 7.82435083f, -2.51527429f, -4.24328852f, 2.36876059f, 1.94595242f, -2.59290171f, -6.62389565f, 3.32567835f, 2.13659120f, 4.09299326f, 3.48293996f, 2.64965177f, -3.19157362f, 13.37204266f, -0.50297594f, -4.57448196f, 3.95582604f, -0.69038916f, 0.10098404f, 1.18737555f, 3.65761185f, -5.69623756f, -2.03357077f, 1.02868807f, -1.38448596f, -0.05690211f, 
-8.48874187f, 0.56755424f, 1.45485961f, 0.66273880f, 0.06495565f, 1.79539490f, 8.46864319f, -1.22696662f, -1.87585378f, -0.99768794f, 2.72801924f, -0.66980243f, -2.31924677f, 0.33271110f, 0.11666083f, 1.86980045f, 5.95332909f, 7.38583708f, -2.80956483f, 6.79227638f, -6.78070831f, 1.21884382f, -1.40695429f, 0.90236962f, -1.13695288f, 0.50760663f, 1.00955284f, -5.39029121f, 0.24987072f, 2.24283314f, -4.02145576f, 2.18057394f, -3.35627747f, 1.26061773f, 1.30342579f, 0.11311233f, -1.11199212f, -4.06509686f, 5.82649660f, -1.24059582f, 5.51652861f, -1.90937877f, 1.10658336f, -0.47065550f, -2.39167786f, -1.95931304f, 4.12717247f, 1.15396059f, 1.26015663f, 7.97836876f, 7.33633423f, 2.27785325f, -2.83802366f, -2.74850106f, 0.86126029f, 6.18781090f, -1.43707538f, -6.97134876f, -3.25486469f, -1.95214593f, 0.91066706f, 0.89637989f, 1.06481194f, 6.25791073f, 0.81779671f, -1.08384395f, -3.21191931f, 2.04216075f, 4.76030350f, -2.37217665f, -1.42571259f, -6.35876131f, 4.62536526f, -5.40060568f, -3.14868999f, -1.00587153f, 1.80662942f, -7.03201485f, 6.08373499f, 0.99862772f, 2.21717811f, 4.06814623f, 6.02428913f, 5.33422756f, -0.87013257f, -2.22477579f, -2.51505303f, 5.82925224f, -0.82854009f, -4.30698347f, -1.75007713f, 2.08352375f, -2.25235629f, 1.17517352f, 5.77717733f, 2.27472878f, 2.72778273f, -1.95411634f, -4.52602863f, 1.13983536f, 1.16340065f, -2.02740526f, -3.11290503f, -1.94906235f, 1.54855204f, -4.52984142f, 1.97465122f, -1.79415476f, 4.03510094f, -8.45349979f, 10.87430096f, 2.19863629f, -5.39083815f, 5.86213875f, 6.25744534f, 6.52600002f, -4.72149038f, -1.75254321f, -5.51459169f, 7.03155518f, -2.01889277f, -4.58441257f, -3.61226106f, 0.42395937f, -0.93263882f, 2.28703761f, 2.80611467f, 2.59498215f, 0.65989012f, -1.51268566f, -4.49465561f, -4.70453882f, 5.44696808f, -4.37603617f, 0.46670085f, 2.82488608f, 2.18854523f, -2.04817152f, 1.19557285f, 1.53618634f, 4.44758606f, -7.31593513f, 7.43966007f, -3.55480957f, -5.29834652f, 2.14622784f, 1.65194583f, 2.71262598f, 
-4.86145496f, 0.79726243f, -8.88541985f, 1.19627261f, 0.79660845f, -1.98016644f, 1.03741014f, -3.93128228f, 1.05535269f, 2.01378822f, -0.46086323f, -0.77754641f, -1.43942690f, 0.49809402f, -2.27861357f, -3.29815221f, 0.38201320f, -3.98481083f, 4.88261318f, -0.44555628f, -2.57224536f, 2.35001850f, -2.65835261f, -2.43422794f, -2.97889376f, 1.07349825f, 1.88157082f, 4.74075413f, 0.60376728f, -0.48894715f, -1.15800071f, 4.68110943f, -0.86976886f, 1.49192941f, 0.62665290f, 0.20652676f, 0.53916287f, -1.45706177f, 0.66133004f, 1.34405875f, -4.27689552f, -0.20838106f, -5.14266443f, -1.29718637f, -1.74506426f, -0.86022055f, -3.57553625f, 0.46880072f, -1.25287139f, 3.28596354f, 11.33191013f, 1.23942876f, -3.87616491f, 7.57880497f, -0.22940339f, -5.68512678f, -1.94969654f, 5.85449600f, 3.75705457f, 4.24395847f, 1.60086083f, 2.62553668f, -0.93964291f, 5.84753895f, -0.79931092f, 0.48274064f, 2.07170033f, 3.02243996f, 2.63509989f, -0.76043403f, -1.64048159f, -6.17683458f, -3.09974527f, -2.12773156f, -0.89379883f, 2.82242465f, -1.99981332f, -0.08763933f, 0.01921120f, -1.94142103f, 2.48067307f, 0.41083777f, 8.24922180f, -1.84516132f, -1.39224625f, 5.03956223f, 0.49562740f, -5.28296328f, -0.20005548f, 3.13672113f, 0.51187158f, 7.11563921f, 6.43059587f, 3.48430967f, -5.37095928f, 8.03863049f, -5.53923941f, -2.16421175f, -3.77641368f, 3.29633045f, 5.04030085f, 2.25945377f, -3.04169011f, -2.16198015f, -2.49559617f, -0.26252726f, -6.99201345f, 2.87374353f, -0.12568980f, 0.23314142f, -1.32087135f, 4.39030552f, -0.24638844f, -4.37242651f, 14.09276772f, 1.23987353f, -1.72249663f, 0.31124914f, -2.13725138f, -3.74915648f, -1.87147236f, 0.47318631f, 1.13337576f, 3.00416899f, 8.82548523f, 4.80538750f, -5.28486395f, 5.51870108f, -5.15801477f, 0.95712411f, -1.50416136f, 2.34657240f, 4.20726633f, 5.56757259f, -3.30645251f, -3.39945269f, -2.68488026f, -2.53525281f, -3.15145874f, 2.74529529f, -0.96283442f, 2.87778258f, 0.22186530f, 1.24905694f, -7.07941198f, -5.45916176f, 3.46988297f, 0.92430985f, 
-0.98330998f, -2.23672342f, -3.03262734f, 0.73941302f, 0.98004431f, 0.83219361f, 7.17411804f, 4.27849865f, 0.14765590f, 8.61269569f, 9.04497051f, 1.53991723f, -2.08305025f, -4.34939337f, 0.63786775f, 2.60098696f, 0.02432060f, -1.48516297f, -4.06825686f, 5.12420368f, -0.75312757f, 1.96927559f, 4.91575956f, 3.41533065f, 3.62557888f, -4.35002136f, -5.91343403f, 0.45026422f, 4.93286371f, 3.45830250f, -4.39032364f, -0.51697755f, -7.41543341f, -3.06703568f, 1.01196158f, 2.47106576f, 5.54014874f, -4.65312243f, 8.61000633f, 8.25905323f, -1.41497111f, 8.69221878f, 0.40090930f, 1.11325574f, -1.67089832f, -4.01080132f, 1.07925677f, 2.68086481f, -0.73093414f, -1.35081220f, -7.85765076f, -5.98989439f, -0.04651213f, 4.63693142f, 2.07757711f, -0.22652936f, 3.45525455f, -0.69198442f, -10.39761639f, -2.02106953f, 4.77755499f, -2.67665577f, -1.72481167f, 4.49634743f, -2.55717134f, -4.55044937f, 0.46377492f, -3.08933020f, 3.86891365f, -2.79104614f, 8.36974335f, 0.86471701f, -5.39342690f, 12.54906940f, -0.41536295f, -5.29502535f, -3.94430566f, -5.67391300f, -4.65079165f, 2.22505951f, -0.30000746f, 2.27855444f, -4.81604433f, -1.73440599f, 4.68784523f, 5.00208044f, 0.18863934f, -1.74989462f, 3.17923450f, -1.59773099f, -12.59962940f, -1.54495025f, -0.00576371f, 1.79913878f, -2.43449807f, 1.49516344f, -3.90507102f, 1.68647158f, 4.50177765f, -5.32286358f, 3.47539330f, -2.90529680f, 1.61576962f, 0.83679676f, -5.55615807f, 3.78939056f, -4.46644831f, -5.95550919f, 0.37808037f, 0.51334500f, 1.74658906f, -0.82085419f, -0.65387219f, 3.67790437f, 0.03758264f, -2.42622781f, 1.83335185f, 4.73835945f, -0.83536482f, -0.03993917f, 3.78230667f, -4.81265640f, -8.26869011f, -1.30363441f, -2.09106350f, -3.96769738f, -1.89037073f, 0.38682747f, 0.05434489f, 5.72213697f, 0.55685395f, -3.47729349f, -1.11535001f, 2.09416127f, 5.08877802f, 5.72183466f, 1.29632664f, 0.16822398f, -2.43180108f, 3.49967623f, 2.15753818f, -0.26548505f, 3.24446392f, -0.00599277f, 1.08215356f, -0.23225522f, -2.40723038f, 0.18496060f, 
-3.70608735f, -0.19918591f, -1.64028871f, 0.80792952f, -0.85334057f, -2.52314138f, -3.12099195f, 0.17949918f, -0.82650864f, 2.32224989f, 9.56476116f, -0.20134282f, -0.48428559f, 2.86784410f, 0.07289505f, -3.92880869f, -2.11887884f, 0.59164631f, 6.31267452f, 7.49149418f, 2.88749456f, 2.40504885f, -3.57608175f, -1.48019314f, -0.69410253f, 0.90275228f, -0.34111357f, 2.19190216f, 3.39090061f, 3.39631820f, -5.19105434f, 2.67546582f, -2.56549048f, -0.59797800f, -4.21802664f, 0.63918972f, -0.69969130f, 0.47496963f, -4.30976725f, 0.16531238f, -3.59595251f, -0.76877379f, 11.79971790f, -0.93276632f, -1.48630571f, 8.04754066f, 2.09168458f, -3.77018499f, -4.19337654f, 0.26171905f, 1.99359691f, 8.96759701f, 8.39609814f, 6.19231987f, -5.36037970f, 4.69818354f, -4.22453928f, -4.61665344f, -2.52073431f, 1.34026706f, 2.80182385f, 2.56681514f, -4.04676390f, -3.01466990f, -4.10480118f, 0.38737059f, -0.37146521f, -2.26529670f, -1.72867084f, 0.93472683f, -2.47562981f, 0.89871657f, -1.67618203f, -0.28950238f, 5.30124855f, -0.14731219f, -0.81319761f, -1.11265934f, 0.11356127f, -2.52802444f, -1.93826056f, 1.06187987f, 1.48062325f, 4.28070498f, 5.69893932f, 9.26904392f, -4.23773003f, 5.78582096f, -6.18445301f, -2.85200453f, -5.30461454f, -4.16009140f, -0.07239690f, 4.11531162f, -1.12266588f, -1.50265646f, 0.47661865f, -1.90043914f, -6.48978710f, 1.71005368f, 0.18256521f, -0.88272136f, -0.51324779f, -0.78045660f, -5.21036625f, -4.11805344f, 3.99454761f, -1.04999924f, -6.99629354f, -5.02737141f, 0.94748145f, -2.35882139f, 4.13982439f, -1.41835535f, 7.56763077f, 3.97024012f, -4.08156776f, 6.90305424f, 0.53571963f, -2.22625160f, -2.09144926f, -4.98530245f, -0.15102190f, 0.59995949f, 3.28562784f, 0.77991986f, -3.08389306f, 3.34046674f, 0.41394949f, 5.10031366f, 2.99692893f, 0.17706826f, 2.85998058f, -6.68330860f, -6.72653008f, -0.04071128f, 3.71085787f, 3.17834806f, -4.88019037f, 6.74075413f, -7.41782188f, -5.22026348f, -1.94595623f, -3.61318684f, 1.85610664f, 1.08613706f, 6.41580677f, 
1.46376514f, -4.11524010f, 9.59146214f, -2.92772651f, -1.70753336f, -1.51594138f, -4.88185692f, 1.47331417f, -2.23893595f, 4.98459148f, 1.29359996f, -2.29221845f, -0.99594390f, 3.05759239f, 6.86030054f, 2.40487719f, 3.28339863f, 7.72739315f, -3.60563445f, -9.73502827f, -1.51672328f, -0.08473521f, -2.43673515f, -3.26616001f, 3.63767886f, -11.25394535f, -5.17597103f, -1.27523947f, -7.82669783f, 0.67929745f, -4.50530529f, 5.49323797f, 6.78993320f, -2.28033876f, 4.61412525f, 2.55109429f, -12.38607693f, -0.63024014f, -3.45992327f, -0.84092742f, -0.03252453f, 4.58635283f, 5.28213978f, -1.28417206f, -1.71185923f, -0.26850975f, 8.28257561f, 4.47432184f, 2.72818279f, 8.42217731f, -4.22216320f, -8.95128918f, -1.57179546f, 1.34253705f, -5.47035217f, -5.50866985f, 4.64156532f, -6.11207914f, -5.46734476f, 3.54298997f, -2.79237103f, -0.70766860f, -3.62739944f, 3.22660995f, -2.02262759f, 0.11224222f, 2.63832402f, -0.91955596f, -4.65958309f, -0.29729855f, -1.78957534f, -0.40749407f, 0.51688713f, 0.83725226f, 0.30945438f, 1.20769620f, -1.75219965f, 2.59689760f, 5.01501608f, -1.59034789f, 0.58155286f, 3.75831509f, -5.26110506f, -8.65382767f, -6.19066620f, -0.61932850f, -2.71863723f, -0.87443137f, 3.40582991f, -1.27868056f, 3.51236677f, -2.07806540f, -0.85076392f, -1.14599180f, 1.16361260f, 1.86411846f, 5.86179352f, 0.69029891f, -0.06060839f, 1.54649436f, -0.60351688f, 1.51970077f, 0.04187265f, 1.64540339f, 2.75502157f, 2.46308279f, 1.69071770f, -3.23827076f, 0.92096543f, -3.09458661f, -1.23823690f, 0.24035048f, -0.74456501f, -1.85476089f, -0.32914662f, -2.10325241f, 1.19795251f, -2.05372071f, 1.02114081f, 2.56286955f, 0.42165697f, -1.65826249f, 4.00724554f, -2.18727994f, -1.05848944f, -0.52338278f, -0.28714985f, 8.08780861f, 5.04444599f, 3.51866961f, 3.37445784f, -1.96067202f, -1.21509445f, -3.96595931f, -0.80801201f, 0.76944816f, 1.80147493f, 4.14419460f, -0.12201095f, -2.77788162f, 1.13284469f, -2.05441403f, -0.61129224f, -2.69690657f, 1.91634214f, -2.17146754f, -0.22308528f, 
-6.02561045f, 0.49161875f, -6.74280357f, -4.62689781f, 2.47910833f, 1.86534905f, -3.24152899f, -1.39898300f, 0.29427958f, -2.16338181f, 0.90073711f, 1.75551236f, 4.42651892f, 8.34437466f, 5.50070190f, 5.68162251f, 1.65345454f, -2.72315669f, -5.43411493f, -0.29380533f, 1.07508349f, -1.73533511f, 2.56912184f, 3.62010550f, -6.30422783f, 1.74158525f, -1.22070909f, -0.80982518f, -4.14757967f, 4.29217434f, 0.70600843f, -2.09282112f, -5.09018898f, -0.11623126f, -5.99775553f, -4.66743088f, 1.61512172f, -1.30276895f, -3.17103505f, -0.26310229f, -1.00843918f, -0.77664804f, -2.05240250f, 0.04728425f, 1.15720487f, 4.01001406f, 7.24615860f, 2.55452180f, -5.76347876f, 0.34683830f, -6.05540276f, -4.70677900f, -0.93182588f, -4.37759733f, 2.93209839f, 1.63947964f, -2.43563962f, 1.35213876f, 0.00670356f, -0.02742785f, -2.16460943f, 1.39449501f, 0.23929763f, 2.37476778f, -4.17733765f, -0.81475425f, -6.15027046f, -5.74441719f, 3.53978682f, 0.66798484f}); + + sd::ops::deconv2d_tf op; + auto result = op.evaluate({&input0, &input1, &input2}, {}, {7,7, 2,2, 0,0, 1,1, 1,1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, Test_Dilation2D_Again_1) { + auto x = NDArrayFactory::create('c', {4, 128, 128, 4}); + auto w = NDArrayFactory::create('c', {4, 5, 4}); + auto exp = NDArrayFactory::create('c', {4, 64, 43, 4}); + + + sd::ops::dilation2d op; + auto result = op.evaluate({&x, &w}, {}, {1, 1,5,7,1, 1,2,3,1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, Test_Dilation2D_Again_2) { + auto x = NDArrayFactory::create('c', {4, 26, 19, 4}); + auto w = NDArrayFactory::create('c', {11, 7, 4}); + + sd::ops::dilation2d op; + auto result = 
op.evaluate({&x, &w}, {}, {0, 1,2,3,1, 1,3,2,1}); + ASSERT_EQ(Status::OK(), result.status()); + +} + +TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_1) { + TypeParam _expGradWpB[] = {1603.7102981f, 10645.6278024f, 5975.4227995f, 17697.0903052f, 12133.6353024f, 26535.0528052f, 1779.221097f, 11795.5686029f, 6721.9835994f, 19904.0811062f, 13775.2461029f, 30123.0936062f, 1954.7318976f, 12945.5094033f, 7468.5443993f, 22111.071907f, 15416.8569033f, 33711.134407f, 2130.2426974f, 14095.4502038f, 8215.1051992f, 24318.0627081f, 17058.4677038f, 37299.1752081f, 2305.7534972f, 15245.3910042f, 8961.6659991f, 26525.0535091f, 18700.0785042f, 40887.2160091f, 2481.2642970f, 16395.3318047f, 9708.2267991f, 28732.0443100f, 20341.6893047f, 44475.2568100f, 2656.7750968f, 17545.2726051f, 10454.7875990f, 30939.0351110f, 21983.3001051f, 48063.2976110f, 2832.2858966f, 18695.2134056f, 11201.3483989f, 33146.0259119f, 23624.9109056f, 51651.3384119f, 3007.7966964f, 19845.1542060f, 11947.9091988f, 35353.0167129f, 25266.5217060f, 55239.3792129f, 3183.3074962f, 20995.095006f, 12694.4699987f, 37560.007513f, 26908.132506f, 58827.4200139f}; + Nd4jLong _expGradWpS[] {4, 10, 6, 1, 1, 6, 1, 1, 1, typeid(TypeParam) == typeid(float) ? 
8192 : 16384, 1, 99}; + NDArray expGWP(_expGradWpB, _expGradWpS); + expGWP.permutei({2,3,1,0}); + + TypeParam _expGradWdB[] = {2074.21032f, 2082.76104f, 2091.31176f, 2099.86248f, 2108.4132f, 2159.71752f, 2168.26824f, 2176.81896f, 2185.36968f, 2193.9204f, 2245.22472f, 2253.77544f, 2262.32616f, 2270.87688f, 2279.4276f, 2330.73192f, 2339.28264f, 2347.83336f, 2356.38408f, 2364.9348f, 2416.23912f, 2424.78984f, 2433.34056f, 2441.89128f, 2450.442f, 3112.99344f, 3122.06328f, 3131.13312f, 3140.20296f, 3149.2728f, 3203.69184f, 3212.76168f, 3221.83152f, 3230.90136f, 3239.9712f, 3294.39024f, 3303.46008f, 3312.52992f, 3321.59976f, 3330.6696f, 3385.08864f, 3394.15848f, 3403.22832f, 3412.29816f, 3421.368f, 3475.78704f, 3484.85688f, 3493.92672f, 3502.99656f, 3512.0664f, 4255.60056f, 4265.18952f, 4274.77848f, 4284.36744f, 4293.9564f, 4351.49016f, 4361.07912f, 4370.66808f, 4380.25704f, 4389.846f, 4447.37976f, 4456.96872f, 4466.55768f, 4476.14664f, 4485.7356f, 4543.26936f, 4552.85832f, 4562.44728f, 4572.03624f, 4581.6252f, 4639.15896f, 4648.74792f, 4658.33688f, 4667.92584f, 4677.5148f, 2140.10988f, 2148.92016f, 2157.73044f, 2166.54072f, 2175.351f, 2228.21268f, 2237.02296f, 2245.83324f, 2254.64352f, 2263.4538f, 2316.31548f, 2325.12576f, 2333.93604f, 2342.74632f, 2351.5566f, 2404.41828f, 2413.22856f, 2422.03884f, 2430.84912f, 2439.6594f, 2492.52108f, 2501.33136f, 2510.14164f, 2518.95192f, 2527.7622f, 3204.849f, 3214.1784f, 3223.5078f, 3232.8372f, 3242.1666f, 3298.143f, 3307.4724f, 3316.8018f, 3326.1312f, 3335.4606f, 3391.437f, 3400.7664f, 3410.0958f, 3419.4252f, 3428.7546f, 3484.731f, 3494.0604f, 3503.3898f, 3512.7192f, 3522.0486f, 3578.025f, 3587.3544f, 3596.6838f, 3606.0132f, 3615.3426f, 4373.41212f, 4383.26064f, 4393.10916f, 4402.95768f, 4412.8062f, 4471.89732f, 4481.74584f, 4491.59436f, 4501.44288f, 4511.2914f, 4570.38252f, 4580.23104f, 4590.07956f, 4599.92808f, 4609.7766f, 4668.86772f, 4678.71624f, 4688.56476f, 4698.41328f, 4708.2618f, 4767.35292f, 4777.20144f, 4787.04996f, 
4796.89848f, 4806.747f}; + Nd4jLong _expGradWdS[] = {4, 2, 3, 5, 5, 75, 25, 5, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; + NDArray expGWD(_expGradWdB, _expGradWdS); + expGWD.permutei({2,3,1,0}); + + TypeParam _expEB[] = {5.0103f, 10.17147f, 15.48408f, 20.9487f, 26.5659f, 26.6832f, 21.65628f, 16.47507f, 11.139f, 5.6475f, 10.79727f, 21.90255f, 33.31698f, 45.0417f, 57.07785f, 57.3267f, 46.49334f, 35.34513f, 23.88093f, 12.0996f, 17.37801f, 35.22744f, 53.55f, 72.3474f, 91.62135f, 92.016f, 74.57958f, 56.66148f, 38.25999f, 19.3734f, 24.76962f, 50.18034f, 76.23444f, 102.9342f, 130.2819f, 130.8366f, 105.9834f, 80.47542f, 54.31038f, 27.486f, 32.9892f, 66.79545f, 101.4216f, 136.8705f, 173.145f, 173.874f, 140.7732f, 106.83825f, 72.0663f, 36.4545f, 33.8298f, 68.49375f, 103.9947f, 140.3355f, 177.519f, 178.248f, 144.3066f, 109.51395f, 73.8672f, 37.3635f, 28.85658f, 58.39302f, 88.6116f, 119.5146f, 151.1043f, 151.716f, 122.76444f, 93.11934f, 62.77842f, 31.7394f, 23.00409f, 46.52748f, 70.57188f, 95.139f, 120.23055f, 120.7107f, 97.6311f, 74.02194f, 49.88151f, 25.2081f, 16.25523f, 32.86293f, 49.82424f, 67.1403f, 84.81225f, 85.1466f, 68.83818f, 52.17045f, 35.14227f, 17.7525f, 8.5929f, 17.36517f, 26.31738f, 35.4501f, 44.7639f, 44.9382f, 36.31728f, 27.51357f, 18.5265f, 9.3555f, 8.63807f, 17.45032f, 26.43736f, 35.5998f, 44.93825f, 45.1399f, 36.46882f, 27.6199f, 18.59253f, 9.3861f, 18.18615f, 36.72737f, 55.62488f, 74.8799f, 94.49365f, 94.9122f, 76.65698f, 58.03937f, 39.05815f, 19.7121f, 28.66254f, 57.86775f, 87.61746f, 117.9135f, 148.7577f, 149.4084f, 120.63768f, 91.31331f, 61.43346f, 30.9963f, 40.08554f, 80.90806f, 122.47f, 164.7738f, 207.8219f, 208.72f, 168.48412f, 127.49662f, 85.75506f, 43.257f, 52.47345f, 105.8849f, 160.2374f, 215.534f, 271.77775f, 272.9385f, 220.2695f, 166.6442f, 112.05955f, 56.5125f, 53.82975f, 108.6158f, 164.3612f, 221.069f, 278.74225f, 279.903f, 225.8777f, 170.8778f, 114.90025f, 57.942f, 45.14002f, 91.0585f, 137.75788f, 185.2406f, 233.5091f, 
234.4682f, 189.16564f, 143.06998f, 96.17878f, 48.4896f, 35.43048f, 71.45487f, 108.075f, 145.2927f, 183.1098f, 183.852f, 148.29504f, 112.13319f, 75.36462f, 37.9875f, 24.68283f, 49.76831f, 75.25766f, 101.1521f, 127.45285f, 127.9629f, 103.1927f, 78.01253f, 52.42117f, 26.4174f, 12.87877f, 25.96222f, 39.25096f, 52.7456f, 66.44675f, 66.7094f, 53.78542f, 40.6531f, 27.31183f, 13.761f, 12.59184f, 25.38317f, 38.37464f, 51.5669f, 64.9606f, 65.2566f, 52.61336f, 39.76673f, 26.71606f, 13.4607f, 26.23903f, 52.88419f, 79.93678f, 107.3981f, 135.26945f, 135.8777f, 109.53262f, 82.77361f, 55.59937f, 28.0086f, 40.96107f, 82.54206f, 124.74492f, 167.5716f, 211.02405f, 211.9608f, 170.83578f, 129.07914f, 86.68893f, 43.6632f, 56.77746f, 114.39578f, 172.85756f, 232.1654f, 292.3219f, 293.6034f, 236.60084f, 178.74182f, 120.02374f, 60.444f, 73.7077f, 148.48435f, 224.3332f, 301.2575f, 379.2605f, 380.903f, 306.9058f, 231.82015f, 155.6428f, 78.3705f, 75.6397f, 152.36785f, 230.1877f, 309.1025f, 389.1155f, 390.758f, 314.8288f, 237.79165f, 159.6433f, 80.3805f, 62.89546f, 126.67598f, 191.34416f, 256.9026f, 323.3539f, 324.7004f, 261.56684f, 197.53262f, 132.59514f, 66.7518f, 48.97887f, 98.63226f, 148.96212f, 199.9704f, 251.65905f, 252.6933f, 203.53098f, 153.68244f, 103.14573f, 51.9189f, 33.87043f, 68.19769f, 102.98308f, 138.2279f, 173.93345f, 174.6392f, 140.64322f, 106.18261f, 71.25607f, 35.8623f, 17.55064f, 35.33327f, 53.34854f, 71.5971f, 90.0796f, 90.4406f, 72.82556f, 54.97463f, 36.88716f, 18.5625f, 13.0455f, 26.44707f, 40.20528f, 54.3207f, 68.7939f, 68.9112f, 55.84908f, 42.42747f, 28.6458f, 14.5035f, 27.89367f, 56.50575f, 85.83738f, 115.8897f, 146.66385f, 146.9127f, 118.98294f, 90.32793f, 60.94653f, 30.8376f, 44.56161f, 90.21024f, 136.9476f, 184.7754f, 233.69535f, 234.09f, 189.46998f, 143.75268f, 96.93639f, 49.0194f, 63.06642f, 127.59474f, 193.58724f, 261.0462f, 329.9739f, 330.5286f, 267.3786f, 202.75302f, 136.64958f, 69.066f, 83.4252f, 168.69345f, 255.8076f, 344.7705f, 435.585f, 436.314f, 352.7772f, 
267.38025f, 180.1203f, 90.9945f, 84.2658f, 170.39175f, 258.3807f, 348.2355f, 439.959f, 440.688f, 356.3106f, 270.05595f, 181.9212f, 91.9035f, 71.25738f, 144.01542f, 218.2764f, 294.0426f, 371.3163f, 371.928f, 300.57564f, 227.70894f, 153.32562f, 77.4234f, 56.34369f, 113.82228f, 172.43748f, 232.191f, 293.08455f, 293.5647f, 237.1455f, 179.58114f, 120.86991f, 61.0101f, 39.50763f, 79.77813f, 120.81264f, 162.6123f, 205.17825f, 205.5126f, 165.95178f, 125.62125f, 84.51987f, 42.6465f, 20.7321f, 41.84877f, 63.35058f, 85.2381f, 107.5119f, 107.6862f, 86.92608f, 65.77797f, 44.2413f, 22.3155f, 22.71767f, 45.82912f, 69.33496f, 93.2358f, 117.53225f, 117.7339f, 94.98322f, 71.8351f, 48.28893f, 24.3441f, 47.44335f, 95.68097f, 144.71408f, 194.5439f, 245.17165f, 245.5902f, 198.07778f, 149.76377f, 100.64695f, 50.7261f, 74.19534f, 149.59215f, 226.19226f, 303.9975f, 383.0097f, 383.6604f, 309.35688f, 233.84091f, 157.11066f, 79.1643f, 102.99194f, 207.59926f, 313.8244f, 421.6698f, 531.1379f, 532.036f, 428.89372f, 324.12142f, 217.71666f, 109.677f, 133.85145f, 269.7389f, 407.6654f, 547.634f, 689.64775f, 690.8085f, 556.7615f, 420.6602f, 282.50155f, 142.2825f, 135.20775f, 272.4698f, 411.7892f, 553.169f, 696.61225f, 697.773f, 562.3697f, 424.8938f, 285.34225f, 143.712f, 112.43842f, 226.5337f, 342.28828f, 459.7046f, 578.7851f, 579.7442f, 467.14324f, 352.87078f, 236.92438f, 119.3016f, 87.55128f, 176.35527f, 266.4138f, 357.7287f, 450.3018f, 451.044f, 363.36624f, 274.42479f, 184.21782f, 92.7435f, 60.52803f, 121.89791f, 184.11086f, 247.1681f, 311.07085f, 311.5809f, 250.9655f, 189.50093f, 127.18597f, 64.0194f, 31.35037f, 63.12502f, 95.32456f, 127.9496f, 161.00075f, 161.2634f, 129.86782f, 98.0443f, 65.79223f, 33.111f, 33.43584f, 67.30517f, 101.60864f, 136.3469f, 171.5206f, 171.8166f, 138.32936f, 104.40473f, 70.04206f, 35.2407f, 69.09703f, 139.06819f, 209.91478f, 281.6381f, 354.23945f, 354.8477f, 285.64462f, 215.55961f, 144.59137f, 72.7386f, 107.00307f, 215.32806f, 324.97692f, 435.9516f, 548.25405f, 
549.1908f, 442.02378f, 333.52314f, 223.68693f, 112.5132f, 147.17346f, 296.12378f, 446.85356f, 599.3654f, 753.6619f, 754.9434f, 607.54484f, 458.35382f, 307.36774f, 154.584f, 189.6277f, 381.49435f, 575.6032f, 771.9575f, 970.5605f, 972.203f, 782.2858f, 590.11015f, 395.6728f, 198.9705f, 191.5597f, 385.37785f, 581.4577f, 779.8025f, 980.4155f, 982.058f, 790.2088f, 596.08165f, 399.6733f, 200.9805f, 157.97146f, 317.76398f, 479.38016f, 642.8226f, 808.0939f, 809.4404f, 651.23084f, 491.18462f, 329.29914f, 165.5718f, 122.04087f, 245.45826f, 370.25412f, 496.4304f, 623.98905f, 625.0233f, 502.79898f, 379.18644f, 254.18373f, 127.7889f, 83.74843f, 168.42169f, 254.02108f, 340.5479f, 428.00345f, 428.7092f, 344.83522f, 260.02861f, 174.28807f, 87.6123f, 43.07464f, 86.61527f, 130.62254f, 175.0971f, 220.0396f, 220.4006f, 177.26156f, 133.65263f, 89.57316f, 45.0225f }; + Nd4jLong _expES[] = {4, 2, 3, 10, 10, 300, 100, 10, 1, typeid(TypeParam) == typeid(float) ? 8192 : 16384, 1, 99}; + NDArray expE(_expEB, _expES); + + auto input = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto weightsD = NDArrayFactory::create('c', {2, 3, 5, 5}); + auto weightsP = NDArrayFactory::create('c', {10, 6, 1, 1}); + + auto epsilon = NDArrayFactory::create('c', {2, 3, 10, 10}); + auto epsilonNext = NDArrayFactory::create('c', {2, 10, 6, 6}); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + epsilonNext.linspace(1); + weightsD.permutei({2,3,1,0}); + weightsP.permutei({2,3,1,0}); + + input.applyScalar(scalar::Divide, 100.0, input); + weightsD.applyScalar(scalar::Divide, 100.0, weightsD); + weightsP.applyScalar(scalar::Divide, 100.0, weightsP); + epsilonNext.applyScalar(scalar::Divide, 100.0, epsilonNext); + + sd::ops::sconv2d_bp op; + auto resultBP = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP },{}, {5, 5, 1, 1, 0, 0, 1, 1, 0}, {}); + + ASSERT_EQ(3, resultBP.size()); + + auto _epsilon = resultBP.at(0); + auto _gradWD = resultBP.at(1); + auto _gradWP = resultBP.at(2); + + 
//_gradWP->printBuffer("gradWP"); + + ASSERT_TRUE(_gradWP->isSameShape(&expGWP)); + ASSERT_TRUE(_gradWP->isSameShape(&weightsP)); + + ASSERT_TRUE(_gradWP->equalsTo(&expGWP)); + + //_gradWD->printShapeInfo("gradWD shape"); + + ASSERT_TRUE(_gradWD->isSameShape(&expGWD)); + ASSERT_TRUE(_gradWD->isSameShape(&weightsD)); +// _gradWD->printIndexedBuffer(); + ASSERT_TRUE(_gradWD->equalsTo(&expGWD)); + + ASSERT_TRUE(_epsilon->isSameShape(&input)); + ASSERT_TRUE(_epsilon->isSameShape(&expE)); + + ASSERT_TRUE(_epsilon->equalsTo(&expE)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_2) { + + int bS=3, iH=16,iW=16, iC=3,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=2,dW=2; + int oH=16,oW=16; + int oC=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); + NDArray gradO('c', {bS, oC, oH, oW}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); + NDArray weightsDepth('c', {kH, kW, iC, mC}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); + NDArray weightsPoint('f', {1, 1, iC*mC, oC}, typeid(TypeParam) == typeid(float) ? sd::DataType::FLOAT32 : sd::DataType::DOUBLE); + NDArray bias('c', {1,oC}, {0.5, 0.5}, typeid(TypeParam) == typeid(float) ? 
sd::DataType::FLOAT32 : sd::DataType::DOUBLE); + + NDArray gradI(&input); + NDArray gradWD(&weightsDepth); + NDArray gradWP(&weightsPoint); + NDArray gradB(&bias); + + input = 2.; + weightsDepth.linspace(0.1, 0.1); + weightsPoint.linspace(0.15, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::sconv2d_bp op; + Nd4jStatus status = op.execute({&input, &gradO, &weightsDepth, & weightsPoint, &bias}, + {&gradI, &gradWD, &gradWP, &gradB}, + {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); + + ASSERT_EQ(Status::OK(), status); + + NDArray expGradI = gradI; + NDArray expGradWD = gradWD; + NDArray expGradWP = gradWP; + NDArray expGradB = gradB; + + for( int i=0; i<10; i++ ) { + Nd4jStatus status = op.execute({&input, &gradO, &weightsDepth, & weightsPoint, &bias}, + {&gradI, &gradWD, &gradWP, &gradB}, + {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(expGradI.equalsTo(gradI)); + ASSERT_TRUE(expGradWD.equalsTo(gradWD)); + ASSERT_TRUE(expGradWP.equalsTo(gradWP)); + ASSERT_TRUE(expGradB.equalsTo(expGradB)); + } +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_3) { + + auto input = NDArrayFactory::create('c', {3, 3, 16, 16}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 2, 2}); + auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); + auto bias = NDArrayFactory::create('c', {1, 2}); + + weightsD.permutei({2,3,1,0}); + weightsP.permutei({2,3,1,0}); + + auto epsilonNext = NDArrayFactory::create('c', {3, 2, 14, 14}); + + auto epsilon = NDArrayFactory::create('c', {3, 3, 16, 16}); + + sd::ops::sconv2d_bp op; + auto result = op.evaluate({&input, &epsilonNext, &weightsD, &weightsP}, {}, {2, 2, 1, 1, 0, 0, 2, 2, 0}); + + auto eps = result.at(0); + auto gWD = result.at(1); + auto gWP = result.at(2); + + + ASSERT_TRUE(epsilon.isSameShape(eps)); +} + +////////////////////////////////////////////////////////////////////// 
+TYPED_TEST(TypedConvolutionTests2, sconv2d_bp_4) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int oC=iC*mC; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weightsDepth = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC},{0.07f, 0.19f, 0.348f, 0.652f, 0.588f, 0.956f, 0.387f, 0.687f, 1.326f, 2.022f, 1.878f, 2.67f, 1.071f, 1.515f, 2.982f, 3.966f, 3.534f, 4.614f, 1.606f, 1.982f, 3.932f, 4.748f, 4.428f, 5.308f, + 1.126f, 1.63f, 3.228f, 4.3f, 3.468f, 4.604f, 3.123f, 3.999f, 7.95f, 9.798f, 8.502f, 10.446f, 3.807f, 4.827f, 9.606f, 11.742f,10.158f, 12.39f, 4.198f, 4.958f, 9.884f, 11.468f,10.38f, 12.028f}); + + auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, mC},{19.08f, 19.44f, 19.8f, 20.16f, 12.24f, 12.48f, 12.72f, 12.96f, 22.56f, 23.04f, 23.52f, 24.f, 14.4f, 14.72f, 15.04f, 15.36f, 14.76f, 15.12f, 15.48f, 15.84f, 9.36f, 9.6f, 9.84f, 10.08f}); + + input = 2.; + weightsDepth.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::sconv2d_bp op; + auto results = op.evaluate({&input, &gradO, &weightsDepth, &bias}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* gradI = results.at(0); + auto* gradWD = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradWD)); + ASSERT_TRUE(expGradW.equalsTo(gradWD)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, sconv2d_bp_5) { + + int bS=1, iH=8,iW=8, iC=3,mC=3, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=8,oW=8; + int oC=2; // iC*mC if weightsPoint = nullptr + 
int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + auto weightsDepth = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto weightsPoint = NDArrayFactory::create('c', {1, 1, iC*mC, oC}); + auto bias = NDArrayFactory::create('c', {1,oC}, {1,2}); + + auto gradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradWD = NDArrayFactory::create('f', {kH, kW, iC, mC}); + auto gradWP = NDArrayFactory::create('c', {1, 1, iC*mC, oC}); + auto gradB = NDArrayFactory::create('c', {1,oC}, {1,2}); + + input = 2.; + weightsDepth.linspace(0.1, 0.1); + weightsDepth.linspace(-0.5, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::sconv2d_bp op; + auto status = op.execute({&input, &gradO, &weightsDepth, &weightsPoint, &bias}, {&gradI, &gradWD, &gradWP, &gradB}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); + ASSERT_EQ(Status::OK(), status); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, im2col_bp_1) { + + int bS=3, iH=12,iW=12, iC=6,oC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=12,oW=12; + + // [bS, iC, kH, kW, oH, oW] is de-convoluted to [bS, iC, iH, iW] + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); + NDArray gradO('c', {bS, iC, kH, kW, oH, oW}, sd::DataType::DOUBLE); + NDArray gradI('c', {bS, iC, iH, iW}, sd::DataType::DOUBLE); // output + + sd::ops::im2col_bp op; + Nd4jStatus status = op.execute({&input, &gradO}, {&gradI}, {}, {kH, kW, sH, sW, pH, pW, dH, dW, 1}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_test1) { + + int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=3,oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; 
// 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.2 , 1.65, 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.1 , 2.55, 5.1 , 6. , 5.1 , 6. , 3. , 3.45, + 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , + 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , + 3.9 , 4.35, 8.7 , 9.6 , 8.7 , 9.6 , 4.8 , 5.25, 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 5.7 , 6.15,12.3 ,13.2 ,12.3 ,13.2 , 6.6 , 7.05, + 0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.2 , 1.65, 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 4.2 , 5.1 , 2.1 , 2.55, 5.1 , 6. , 5.1 , 6. , 3. , 3.45, + 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , + 4.2 , 5.1 ,10.2 ,12. ,10.2 ,12. , 6. , 6.9 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 ,12. ,13.8 ,27.6 ,31.2 ,27.6 ,31.2 ,15.6 ,17.4 , 7.8 , 8.7 ,17.4 ,19.2 ,17.4 ,19.2 , 9.6 ,10.5 , + 3.9 , 4.35, 8.7 , 9.6 , 8.7 , 9.6 , 4.8 , 5.25, 9.6 ,10.5 ,21. ,22.8 ,21. ,22.8 ,11.4 ,12.3 , 9.6 ,10.5 ,21. ,22.8 ,21. 
,22.8 ,11.4 ,12.3 , 5.7 , 6.15,12.3 ,13.2 ,12.3 ,13.2 , 6.6 , 7.05}); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); + auto output = results.at(0); + + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_test2) { + + int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=4,oH=4,oW=4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto exp = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.5 , 2.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , + 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , + 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , + 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. 
, 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , + 0.3 , 0.75, 1.5 , 2.4 , 1.5 , 2.4 , 1.5 , 2.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , 2.4 , 3.3 , 6.6 , 8.4 , 6.6 , 8.4 , 6.6 , 8.4 , + 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , + 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 , + 4.2 , 5.1 ,10.2 , 12. ,10.2 , 12. ,10.2 , 12. ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 ,12. , 13.8 ,27.6 , 31.2 ,27.6 , 31.2 ,27.6 , 31.2 }); + input = 0.5; + weights.linspace(0.1, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_test3) { + + int bS=2, iD=4,iH=4,iW=4, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=3,oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto exp = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {2.55, 5.25, 5.25, 2.7, 5.4 , 11.1 , 11.1 , 5.7, 5.4 , 11.1 , 11.1 , 5.7, 2.85, 5.85, 5.85, 3. , 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, + 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. 
, 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, 3.15, 6.45, 6.45, 3.3, 6.6 , 13.5 , 13.5 , 6.9, 6.6 , 13.5 , 13.5 , 6.9, 3.45, 7.05, 7.05, 3.6, + 3.75, 7.65, 7.65, 3.9, 7.8 , 15.9 , 15.9 , 8.1, 7.8 , 15.9 , 15.9 , 8.1, 4.05, 8.25, 8.25, 4.2, 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , + 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , 4.35, 8.85, 8.85, 4.5, 9. , 18.3 , 18.3 , 9.3, 9. , 18.3 , 18.3 , 9.3, 4.65, 9.45, 9.45, 4.8, + 2.55, 5.25, 5.25, 2.7, 5.4 , 11.1 , 11.1 , 5.7, 5.4 , 11.1 , 11.1 , 5.7, 2.85, 5.85, 5.85, 3. , 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, + 5.7 , 11.7 , 11.7 , 6. ,12. , 24.6 , 24.6 , 12.6,12. , 24.6 , 24.6 , 12.6, 6.3 , 12.9 , 12.9 , 6.6, 3.15, 6.45, 6.45, 3.3, 6.6 , 13.5 , 13.5 , 6.9, 6.6 , 13.5 , 13.5 , 6.9, 3.45, 7.05, 7.05, 3.6, + 3.75, 7.65, 7.65, 3.9, 7.8 , 15.9 , 15.9 , 8.1, 7.8 , 15.9 , 15.9 , 8.1, 4.05, 8.25, 8.25, 4.2, 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , + 8.1 , 16.5 , 16.5 , 8.4,16.8 , 34.2 , 34.2 , 17.4,16.8 , 34.2 , 34.2 , 17.4, 8.7 , 17.7 , 17.7 , 9. , 4.35, 8.85, 8.85, 4.5, 9. , 18.3 , 18.3 , 9.3, 9. 
, 18.3 , 18.3 , 9.3, 4.65, 9.45, 9.45, 4.8}); + input = 0.5; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_test4) { + + int bS=2, iD=2,iH=2,iW=2, iC=2,oC=3, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; + int oD=3,oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kD, kH, kW}); + auto exp = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {24.6, 24.6,24.6, 24.6,24.6, 24.6,24.6, 24.6,34.2, 34.2,34.2, 34.2,34.2, 34.2,34.2, 34.2,24.6, 24.6,24.6, 24.6, + 24.6, 24.6,24.6, 24.6,34.2, 34.2,34.2, 34.2,34.2, 34.2,34.2, 34.2}); + input = 0.5; + weights.linspace(0.1, 0.1); + weights.permutei({2, 3, 4, 1, 0}); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_test5) { + int bS=1, oD=5,oH=5,oW=5, oC=3,iC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; + int iD=3,iH=3,iW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, 
oC, iC}); + auto bias = NDArrayFactory::create('c', {oC}); + + auto exp = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}, {-2.9f, -6.8f, -10.7f, -2.6f, -6.1f, -9.6f, -16.9f, -23.9f, -30.9f, -13.1f, -16.6f, -20.1f, -11.6f, -14.7f, -17.8f, -2.0f, -4.7f, -7.4f, -1.7f, -4.0f, -6.3f, -11.5f, + -16.1f, -20.7f, -8.6f, -10.9f, -13.2f, -7.1f, -9.0f, -10.9f, -27.4f, -32.8f, -38.2f, -24.4f, -29.0f, -33.6f, -65.0f, -74.2f, -83.4f, -38.2f, -42.8f, -47.4f, -32.8f, + -36.6f, -40.4f, -18.2f, -20.9f, -23.6f, -15.5f, -17.8f, -20.1f, -39.1f, -43.7f, -48.3f, -22.4f, -24.7f, -27.0f, -18.5f, -20.4f, -22.3f, -10.1f, -11.6f, -13.1f, -7.4f, + -8.5f, -9.6f, -19.3f, -21.5f, -23.7f, -10.7f, -11.8f, -12.9f, -6.8f, -7.5f, -8.2f, -0.2f, -0.5f, -0.8f, 0.1f, 0.2f, 0.3f, -0.7f, -0.5f, -0.3f, 0.4f, 0.5f, 0.6f, 1.9f, 2.4f, + 2.9f, 0.7f, 1.6f, 2.5f, 1.0f, 2.3f, 3.6f, 4.7f, 7.3f, 9.9f, 4.9f, 6.2f, 7.5f, 6.4f, 8.1f, 9.8f, -0.4f, 1.4f, 3.2f, 2.6f, 5.2f, 7.8f, 10.6f, 15.8f, 21.0f, 10.4f, 13.0f, 15.6f, + 15.8f, 19.2f, 22.6f, 6.1f, 7.0f, 7.9f, 8.8f, 10.1f, 11.4f, 20.3f, 22.9f, 25.5f, 12.7f, 14.0f, 15.3f, 16.6f, 18.3f, 20.0f, 14.2f, 16.3f, 18.4f, 16.9f, 19.4f, 21.9f, 40.1f, + 45.1f, 50.1f, 24.4f, 26.9f, 29.4f, 28.3f, 31.2f, 34.1f, -47.2f, -47.8f, -48.4f, -41.8f, -41.6f, -41.4f, -85.4f, -85.f, -84.6f, -41.2f, -41.0f, -40.8f, -33.4f, -32.4f, -31.4f, + -31.f, -29.2f, -27.4f, -25.6f, -23.0f, -20.4f, -45.8f, -40.6f, -35.4f, -17.8f, -15.2f, -12.6f, -10.0f, -6.6f, -3.2f, -65.6f, -62.0f, -58.4f, -50.0f, -44.8f, -39.6f, -89.2f, + -78.8f, -68.4f, -34.4f, -29.2f, -24.f, -14.0f, -7.2f, -0.4f, -20.2f, -18.4f, -16.6f, -10.f, -7.4f, -4.8f, -14.6f, -9.4f, -4.2f, -2.2f, 0.4f, 3.0f, 10.4f, 13.8f, 17.2f, 10.4f, + 14.6f, 18.8f, 20.6f, 25.6f, 30.6f, 53.8f, 63.8f, 73.8f, 35.6f, 40.6f, 45.6f, 48.2f, 54.0f, 59.8f, -3.8f, -4.1f, -4.4f, 1.3f, 1.4f, 1.5f, 1.7f, 1.9f, 2.1f, 1.6f, 1.7f, 1.8f, 7.9f, + 8.4f, 8.9f, 11.5f, 12.4f, 13.3f, 16.6f, 17.9f, 19.2f, 35.9f, 38.5f, 41.1f, 20.5f, 21.8f, 23.1f, 26.8f, 28.5f, 30.2f, 21.2f, 23.0f, 
24.8f, 33.8f, 36.4f, 39.0f, 73.0f, 78.2f, + 83.4f, 41.6f, 44.2f, 46.8f, 56.6f, 60.0f, 63.4f, 16.9f, 17.8f, 18.7f, 24.4f, 25.7f, 27.f, 51.5f, 54.1f, 56.7f, 28.3f, 29.6f, 30.9f, 37.0f, 38.7f, 40.4f, 39.4f, 41.5f, + 43.6f, 46.9f, 49.4f, 51.9f, 100.1f, 105.1f, 110.1f, 54.4f, 56.9f, 59.4f, 63.1f, 66.0f, 68.9f, 42.1f, 45.4f, 48.7f, 47.2f, 50.9f, 54.6f, 104.3f, 111.7f, + 119.1f, 58.3f, 62.0f, 65.7f, 64.6f, 68.7f, 72.8f, 57.4f, 61.9f, 66.4f, 62.5f, 67.4f, 72.3f, 138.5f, 148.3f, 158.1f, 77.2f, 82.1f, 87.0f, 83.5f, 88.8f, 94.1f, + 134.6f, 143.6f, 152.6f, 147.2f, 157.0f, 166.8f, 321.4f, 341.0f, 360.6f, 176.6f, 186.4f, 196.2f, 191.6f, 202.2f, 212.8f, 84.4f, 88.9f, + 93.4f, 91.9f, 96.8f, 101.7f, 197.3f, 207.1f, 216.9f, 106.6f, 111.5f, 116.4f, 115.3f, 120.6f, 125.9f, 106.9f, 112.6f, 118.3f, 114.4f, 120.5f, 126.6f, 245.9f, 258.1f, 270.3f, 132.7f, 138.8f, 144.9f, 141.4f, 147.9f, 154.4f}); + + input.linspace(-10, 0.5); + weights.linspace(0.1, 0.1); + bias = 0.2; + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_test6) { + + int bS=2, oD=4,oH=4,oW=4, oC=5,iC=10, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int iD=3,iH=3,iW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, oC, kD, kH, kW}, {20., 15., 10., 5., 0., -5., -10., -15., 19., 14., 9., 4., -1., -6., -11., -16., 18., 13., 8., 3., -2., -7., -12., -17., + 17., 12., 7., 2., -3., -8., -13., -18., 16., 11., 6., 1., -4., 
-9., -14., -19., 19.9, 14.9, 9.9, 4.9, -0.1, -5.1, -10.1, -15.1, 18.9, 13.9, 8.9, 3.9, -1.1, -6.1, + -11.1, -16.1, 17.9, 12.9, 7.9, 2.9, -2.1, -7.1, -12.1, -17.1, 16.9, 11.9, 6.9, 1.9, -3.1, -8.1, -13.1, -18.1, 15.9, 10.9, 5.9, 0.9, -4.1, -9.1, -14.1, -19.1, + 19.799999, 14.8, 9.8, 4.8, -0.2, -5.2, -10.2, -15.2, 18.799999, 13.8, 8.8, 3.8, -1.2, -6.2, -11.2, -16.200001, 17.799999, 12.8, 7.8, 2.8, -2.2, -7.2, -12.2, + -17.200001, 16.799999, 11.8, 6.8, 1.8, -3.2, -8.2, -13.2, -18.200001, 15.8, 10.8, 5.8, 0.8, -4.2, -9.2, -14.2, -19.200001, 19.700001, 14.7, 9.7, 4.7, -0.3, -5.3, -10.3, -15.3, 18.700001, 13.7, 8.7, 3.7, -1.3, -6.3, -11.3, -16.299999, 17.700001, 12.7, 7.7, 2.7, -2.3, -7.3, -12.3, -17.299999, 16.700001, 11.7, 6.7, 1.7, -3.3, -8.3, -13.3, -18.299999, 15.7, 10.7, 5.7, 0.7, -4.3, -9.3, -14.3, -19.299999, 19.6, 14.6, 9.6, 4.6, -0.4, -5.4, -10.4, -15.4, 18.6, 13.6, 8.6, 3.6, -1.4, -6.4, -11.4, -16.4, 17.6, 12.6, 7.6, 2.6, -2.4, -7.4, -12.4, -17.4, 16.6, 11.6, 6.6, 1.6, -3.4, -8.4, -13.4, -18.4, 15.6, 10.6, 5.6, 0.6, -4.4, -9.4, -14.4, -19.4, 19.5, 14.5, 9.5, 4.5, -0.5, -5.5, -10.5, -15.5, 18.5, 13.5, 8.5, 3.5, -1.5, -6.5, -11.5, -16.5, 17.5, 12.5, 7.5, 2.5, -2.5, -7.5, -12.5, -17.5, 16.5, 11.5, 6.5, 1.5, -3.5, -8.5, -13.5, -18.5, 15.5, 10.5, 5.5, 0.5, -4.5, -9.5, -14.5, -19.5, 19.4, 14.4, 9.4, 4.4, -0.6, -5.6, -10.6, -15.6, 18.4, 13.4, 8.4, 3.4, -1.6, -6.6, -11.6, -16.6, 17.4, 12.4, 7.4, 2.4, -2.6, -7.6, -12.6, -17.6, 16.4, 11.4, 6.4, 1.4, -3.6, -8.6, -13.6, -18.6, 15.4, 10.4, 5.4, 0.4, -4.6, -9.6, -14.6, -19.6, 19.299999, 14.3, 9.3, 4.3, -0.7, -5.7, -10.7, -15.7, 18.299999, 13.3, 8.3, 3.3, -1.7, -6.7, -11.7, -16.700001, 17.299999, 12.3, 7.3, 2.3, -2.7, -7.7, -12.7, -17.700001, 16.299999, 11.3, 6.3, 1.3, -3.7, -8.7, -13.7, -18.700001, 15.3, 10.3, 5.3, 0.3, -4.7, -9.7, -14.7, -19.700001, 19.200001, 14.2, 9.2, 4.2, -0.8, -5.8, -10.8, -15.8, 18.200001, 13.2, 8.2, 3.2, -1.8, -6.8, -11.8, -16.799999, 17.200001, 12.2, 7.2, 2.2, -2.8, -7.8, -12.8, -17.799999, 
16.200001, 11.2, 6.2, 1.2, -3.8, -8.8, -13.8, -18.799999, 15.2, 10.2, 5.2, 0.2, -4.8, -9.8, -14.8, -19.799999, 19.1, 14.1, 9.1, 4.1, -0.9, -5.9, -10.9, -15.9, 18.1, 13.1, 8.1, 3.1, -1.9, -6.9, -11.9, -16.9, 17.1, 12.1, 7.1, 2.1, -2.9, -7.9, -12.9, -17.9, 16.1, 11.1, 6.1, 1.1, -3.9, -8.9, -13.9, -18.9, 15.1, 10.1, 5.1, 0.1, -4.9, -9.9, -14.9, -19.9}, sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oD, oH, oW, oC}, {-5191.349609, -4925.850098, -4660.350098, -4394.850098, -4129.349609, -8859.700195, -8338.700195, -7817.700195, + -7296.700195, -6775.700195, -8518.700195, -8017.700195, -7516.700195, -7015.700195, -6514.700195, -3572.850098, -3327.349854, -3081.850098, -2836.350098, + -2590.850098, -7141.200195, -6640.200195, -6139.199707, -5638.200195, -5137.200195, -11486.400391, -10504.400391, -9522.400391, -8540.400391, -7558.399902, + -11004.400391, -10062.400391, -9120.400391, -8178.399414, -7236.399414, -4254.200195, -3793.200195, -3332.200195, -2871.199951, -2410.200195, -6268.200195, + -5827.200195, -5386.200195, -4945.200195, -4504.200195, -10040.400391, -9178.400391, -8316.400391, -7454.400391, -6592.399902, -9558.400391, -8736.400391, + -7914.400391, -7092.399902, -6270.400391, -3681.199707, -3280.200195, -2879.200195, -2478.200195, -2077.200195, -1963.350098, -1757.850098, -1552.349854, -1346.849976, -1141.349976, -2803.700195, -2402.699951, -2001.699951, -1600.699951, -1199.699951, -2662.699951, -2281.699951, -1900.699951, -1519.699951, -1138.700073, -844.850037, -659.349976, -473.850006, -288.350006, -102.849998, -3313.200195, -2872.199951, -2431.200195, -1990.200195, -1549.199829, -4230.399902, -3368.400391, -2506.400391, -1644.400146, -782.400146, -3948.400146, -3126.400391, -2304.399902, -1482.400146, -660.400269, -926.200195, -525.199951, -124.199951, 276.799927, 677.799805, -1643.400269, -821.400146, 0.599609, 822.600098, 1644.599609, 1005.199951, 2609.199707, 4213.200195, 5817.200195, 7421.200684, 1169.199463, 2693.200195, 4217.199707, 
5741.201172, 7265.203125, 2430.599609, 3172.600098, 3914.600098, 4656.599609, 5398.599609, -1097.400391, -395.400269, 306.599609, 1008.599854, 1710.599731, 1497.199219, 2861.199219, 4225.201172, 5589.200684, 6953.200684, 1661.199219, 2945.199463, 4229.199707, 5513.201172, 6797.200684, 2376.599609, 2998.599854, 3620.599609, 4242.600098, 4864.600098, 1042.799927, 1363.799927, 1684.800171, 2005.799805, 2326.799805, 3681.599609, 4303.599609, 4925.599609, 5547.600098, 6169.599609, 3563.599609, 4145.599609, 4727.600098, 5309.600098, 5891.599609, 2429.800293, 2710.800293, 2991.799805, 3272.799805, 3553.799805, -1594.199829, -1333.199951, -1072.200073, -811.200012, -550.200134, -1692.400024, -1190.399902, -688.400024, -186.400269, 315.600098, -1410.399902, -948.399902, -486.399902, -24.399780, 437.599731, -107.199890, 113.799988, 334.799988, 555.799988, 776.800049, -5.400024, 456.599731, 918.600281, 1380.599731, 1842.599976, 2481.199219, 3365.199219, 4249.199219, 5133.199219, 6017.199219, 2645.199219, 3449.199219, 4253.199707, 5057.199219, 5861.199707, 2268.600098, 2650.599609, 3032.600098, 3414.600098, 3796.599609, 540.599976, 882.600220, 1224.599854, 1566.599854, 1908.600220, 2973.200195, 3617.199707, 4261.199219, 4905.199219, 5549.199219, 3137.199707, 3701.199219, 4265.199707, 4829.199219, 5393.199219, 2214.599609, 2476.600098, 2738.599609, 3000.599854, 3262.599854, 961.800049, 1102.800049, 1243.799927, 1384.800171, 1525.799927, 2619.599609, 2881.599854, 3143.599854, 3405.599609, 3667.599609, 2501.599854, 2723.599609, 2945.599854, 3167.599609, 3389.600098, 1448.799927, 1549.800049, 1650.799927, 1751.800049, 1852.799927, 37.650002, 123.150009, 208.650009, 294.149994, 379.650024, 498.300018, 659.300049, 820.300049, 981.299927, 1142.299927, 439.300018, 580.299988, 721.299927, 862.300049, 1003.300049, 356.149963, 421.649994, 487.150024, 552.649963, 618.150024, 916.799988, 1057.800049, 1198.800171, 1339.800049, 1480.800171, 2429.600098, 2691.600098, 2953.599609, 3215.599609, 
3477.599609, 2111.599854, 2333.599854, 2555.600098, 2777.599609, 2999.600098, 1203.800049, 1304.800049, 1405.799927, 1506.800049, 1607.800049, 589.799927, 670.800049, 751.800049, 832.800049, 913.800049, 1475.599976, 1617.600098, 1759.600098, 1901.600098, 2043.600098, 1157.600098, 1259.600098, 1361.600098, 1463.600098, 1565.599976, 576.799988, 617.800049, 658.799988, 699.799927, 740.800049, 265.649994, 291.149994, 316.650024, 342.150024, 367.649994, 554.300049, 595.299988, 636.299927, 677.299988, 718.299988, 295.300018, 316.300018, 337.299988, 358.299988, 379.300018, 84.149994, 89.650002, 95.150002, 100.650009, 106.150009, 87.150002, 82.650002, 78.150002, 73.650002, 69.150002, 347.299988, 328.300018, 309.300018, 290.299988, 271.299988, 688.300049, 649.299927, 610.299988, 571.300049, 532.300049, 355.650024, 331.149963, 306.649994, 282.149994, 257.649994, 715.800049, 676.800049, 637.799988, 598.800049, 559.800049, 1527.600098, 1429.599976, 1331.599976, 1233.600098, 1135.600098, 2009.600098, 1871.600098, 1733.599976, 1595.600098, 1457.600098, 902.799988, 823.799927, 744.800049, 665.800049, 586.800049, 1588.800049, 1489.800049, 1390.800049, 1291.800049, 1192.799927, 2973.600098, 2755.600098, 2537.600098, 2319.600098, 2101.600098, 3455.600098, 3197.600098, 2939.600098, 2681.600098, 2423.600098, 1475.800049, 1336.800049, 1197.800049, 1058.799927, 919.800049, 615.150024, 550.650024, 486.149994, 421.649994, 357.150024, 1003.300049, 864.300049, 725.299988, 586.300049, 447.300018, 1144.300049, 985.299988, 826.300049, 667.299988, 508.299988, 383.649994, 299.149994, 214.649994, 130.149994, 45.649998, 1843.799927, 1744.799927, 1645.800049, 1546.799927, 1447.800049, 3383.600098, 3165.600098, 2947.600098, 2729.599854, 2511.600098, 3665.599854, 3407.600098, 3149.599854, 2891.599854, 2633.599854, 1530.800171, 1391.800049, 1252.800049, 1113.800049, 974.800171, 3270.599609, 3012.599854, 2754.600098, 2496.599854, 2238.600098, 5433.199707, 4877.200195, 4321.200195, 3765.199707, 
3209.199951, 5597.200195, 4961.199707, 4325.200195, 3689.199707, 3053.199951, 1944.600098, 1606.599854, 1268.600098, 930.599976, 592.600098, 3816.599854, 3438.600342, 3060.599854, 2682.600098, 2304.600098, 5925.200195, 5129.200684, 4333.200195, 3537.199951, 2741.199707, 6089.200684, 5213.200195, 4337.200195, 3461.199707, 2585.200195, 1890.599609, 1432.600220, 974.599976, 516.599976, 58.599976, 799.799927, 580.800171, 361.800110, 142.800110, -76.200073, 495.599976, 37.599976, -420.399902, -878.399902, -1336.400024, 377.599854, -120.399902, -618.399902, -1116.400391, -1614.399902, -513.199951, -772.200012, -1031.199951, -1290.199829, -1549.200073, 3562.800049, 3283.799805, 3004.799805, 2725.800293, 2446.800293, 5921.599609, 5343.599609, 4765.600098, 4187.599609, 3609.599854, 6203.599609, 5585.600098, 4967.600098, 4349.599609, 3731.600098, 2349.799805, 2030.800171, 1711.800293, 1392.800171, 1073.799927, 4908.600098, 4290.599609, 3672.600098, 3054.600098, 2436.600098, 6909.199219, 5633.200684, 4357.200195, 3081.199219, 1805.199463, 7073.200684, 5717.199707, 4361.199219, 3005.199463, 1649.199951, 1782.600464, 1084.599609, 386.599609, -311.400146, -1009.400635, 5454.600098, 4716.599609, 3978.599854, 3240.600098, 2502.600098, 7401.199219, 5885.199219, 4369.200195, 2853.200195, 1337.199219, 7565.199219, 5969.200195, 4373.200195, 2777.199219, 1181.199219, 1728.599854, 910.600098, 92.600098, -725.400391, -1543.400391, 718.799927, 319.800049, -79.200073, -478.200073, -877.200073, -566.400391, -1384.400391, -2202.400391, -3020.400391, -3838.400391, -684.400146, -1542.400391, -2400.400391, -3258.400391, -4116.400391, -1494.200073, -1933.200073, -2372.199707, -2811.200195, -3250.199951, -83.850006, -268.350006, -452.849945, -637.350037, -821.849976, -1094.699951, -1473.699951, -1852.700073, -2231.699707, -2610.699951, -1153.700073, -1552.699829, -1951.699829, -2350.700195, -2749.700195, -1115.350098, -1319.849854, -1524.350098, -1728.849976, -1933.350098, -2026.200073, 
-2425.200195, -2824.200195, -3223.199707, -3622.200195, -6156.400391, -6974.400391, -7792.400391, -8610.400391, -9428.399414, -6474.400391, -7332.400391, -8190.400391, -9048.399414, -9906.399414, -4439.200195, -4878.199707, -5317.200195, -5756.200195, -6195.200195, -2353.199951, -2812.200195, -3271.200195, -3730.200195, -4189.200195, -7110.400391, -8048.400391, -8986.399414, -9924.400391, -10862.400391, -7428.400391, -8406.399414, -9384.399414, -10362.400391, -11340.400391, -5066.200195, -5565.200195, -6064.200195, -6563.200195, -7062.200195, -2555.849854, -2800.349854, -3044.849854, -3289.350098, -3533.850098, -6438.700195, -6937.700195, -7436.700195, -7935.700195, -8434.699219, -6697.700195, -7216.700195, -7735.700195, -8254.699219, -8773.700195, -4087.349854, -4351.850098, -4616.349609, -4880.850098, -5145.350098}, sd::DataType::FLOAT32); + + input.linspace(-27, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output, 1e-3)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_test7) { + + int bS=2, oD=4,oH=4,oW=4, iC=5,oC=10, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=0,pW=0, dD=1,dH=1,dW=1; + int iD=4,iH=4,iW=4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {iC, kD, kH, kW, oC}, {20., 19.5, 19., 18.5, 18., 17.5, 17., 16.5, 16., 15.5, 15., 14.5, 14., 13.5, 13., 12.5, 12., 11.5, 11., 10.5, 10., + 9.5, 9., 8.5, 8., 7.5, 7., 6.5, 6., 5.5, 5., 4.5, 4., 3.5, 3., 2.5, 2., 1.5, 1., 0.5, 0., -0.5, -1., -1.5, -2., -2.5, -3., 
-3.5, -4., -4.5, -5., -5.5, -6., + -6.5, -7., -7.5, -8., -8.5, -9., -9.5, -10., -10.5, -11., -11.5, -12., -12.5, -13., -13.5, -14., -14.5, -15., -15.5, -16., -16.5, -17., -17.5, -18., -18.5, + -19., -19.5, 19.9, 19.4, 18.9, 18.4, 17.9, 17.4, 16.9, 16.4, 15.9, 15.4, 14.9, 14.4, 13.9, 13.4, 12.9, 12.4, 11.9, 11.4, 10.9, 10.4, 9.9, 9.4, 8.9, 8.4, 7.9, + 7.4, 6.9, 6.4, 5.9, 5.4, 4.9, 4.4, 3.9, 3.4, 2.9, 2.4, 1.9, 1.4, 0.9, 0.4, -0.1, -0.6, -1.1, -1.6, -2.1, -2.6, -3.1, -3.6, -4.1, -4.6, -5.1, -5.6, -6.1, -6.6, -7.1, -7.6, -8.1, -8.6, -9.1, -9.6, -10.1, -10.6, -11.1, -11.6, -12.1, -12.6, -13.1, -13.6, -14.1, -14.6, -15.1, -15.6, -16.1, -16.6, -17.1, -17.6, -18.1, -18.6, -19.1, -19.6, 19.799999, 19.299999, 18.799999, 18.299999, 17.799999, 17.299999, 16.799999, 16.299999, 15.8, 15.3, 14.8, 14.3, 13.8, 13.3, 12.8, 12.3, 11.8, 11.3, 10.8, 10.3, 9.8, 9.3, 8.8, 8.3, 7.8, 7.3, 6.8, 6.3, 5.8, 5.3, 4.8, 4.3, 3.8, 3.3, 2.8, 2.3, 1.8, 1.3, 0.8, 0.3, -0.2, -0.7, -1.2, -1.7, -2.2, -2.7, -3.2, -3.7, -4.2, -4.7, -5.2, -5.7, -6.2, -6.7, -7.2, -7.7, -8.2, -8.7, -9.2, -9.7, -10.2, -10.7, -11.2, -11.7, -12.2, -12.7, -13.2, -13.7, -14.2, -14.7, -15.2, -15.7, -16.200001, -16.700001, -17.200001, -17.700001, -18.200001, -18.700001, -19.200001, -19.700001, 19.700001, 19.200001, 18.700001, 18.200001, 17.700001, 17.200001, 16.700001, 16.200001, 15.7, 15.2, 14.7, 14.2, 13.7, 13.2, 12.7, 12.2, 11.7, 11.2, 10.7, 10.2, 9.7, 9.2, 8.7, 8.2, 7.7, 7.2, 6.7, 6.2, 5.7, 5.2, 4.7, 4.2, 3.7, 3.2, 2.7, 2.2, 1.7, 1.2, 0.7, 0.2, -0.3, -0.8, -1.3, -1.8, -2.3, -2.8, -3.3, -3.8, -4.3, -4.8, -5.3, -5.8, -6.3, -6.8, -7.3, -7.8, -8.3, -8.8, -9.3, -9.8, -10.3, -10.8, -11.3, -11.8, -12.3, -12.8, -13.3, -13.8, -14.3, -14.8, -15.3, -15.8, -16.299999, -16.799999, -17.299999, -17.799999, -18.299999, -18.799999, -19.299999, -19.799999, 19.6, 19.1, 18.6, 18.1, 17.6, 17.1, 16.6, 16.1, 15.6, 15.1, 14.6, 14.1, 13.6, 13.1, 12.6, 12.1, 11.6, 11.1, 10.6, 10.1, 9.6, 9.1, 8.6, 8.1, 7.6, 7.1, 6.6, 6.1, 5.6, 5.1, 4.6, 4.1, 3.6, 3.1, 
2.6, 2.1, 1.6, 1.1, 0.6, 0.1, -0.4, -0.9, -1.4, -1.9, -2.4, -2.9, -3.4, -3.9, -4.4, -4.9, -5.4, -5.9, -6.4, -6.9, -7.4, -7.9, -8.4, -8.9, -9.4, -9.9, -10.4, -10.9, -11.4, -11.9, -12.4, -12.9, -13.4, -13.9, -14.4, -14.9, -15.4, -15.9, -16.4, -16.9, -17.4, -17.9, -18.4, -18.9, -19.4, -19.9}, sd::DataType::FLOAT32); + NDArray expOutput('c', {bS, oC, oD, oH, oW}, {-1907.199951, -3324.499756, -3307.199707, -3289.899902, -2814.799805, -4664.800293, -4640.199707, -4615.600098, + -2755.599854, -4566.400391, -4541.800293, -4517.199707, -2696.400146, -4468., -4443.400391, -4418.799805, -1735.999878, -2542.199951, -2527.600098, -2513., + -1592.800049, -1355.999756, -1346.799805, -1337.599854, -1554.400024, -1319.199829, -1310.000122, -1300.800049, -1516., -1282.400024, -1273.200195, -1263.999878, + -1579.200073, -2308.599854, -2294., -2279.400146, -1439.199951, -1208.799683, -1199.599976, -1190.399902, -1400.800049, -1172., -1162.800049, -1153.600098, + -1362.399902, -1135.199951, -1126., -1116.799805, -1422.400024, -2075., -2060.399902, -2045.799683, -1285.599976, -1061.599854, -1052.399902, -1043.200195, + -1247.199951, -1024.800049, -1015.599976, -1006.400146, -1208.799927, -988.000122, -978.799683, -969.599976, -1859.199951, -3228.75, -3211.949951, -3195.150146, -2719.800049, -4475.299805, -4451.699707, -4428.100098, -2662.600098, -4380.899902, -4357.300293, -4333.699707, -2605.399902, -4286.5, -4262.899902, -4239.300293, -1643.999878, -2358.700195, -2345.099854, -2331.5, -1410.800049, -992.999756, -985.799438, -978.600098, -1376.400024, -964.199707, -957., -949.800049, -1342., -935.399902, -928.199951, -921.000122, -1495.200073, -2141.099854, -2127.5, -2113.900391, -1273.199951, -877.799683, -870.599976, -863.39978, -1238.800049, -849., -841.800171, -834.599976, -1204.400024, -820.199707, -813., -805.799438, -1346.400146, -1923.500122, -1909.899902, -1896.299927, -1135.599976, -762.599976, -755.399658, -748.200195, -1101.199951, -733.800049, -726.599854, -719.400024, 
-1066.800049, -705., -697.800171, -690.599976, -1811.199951, -3133., -3116.699951, -3100.399902, -2624.799805, -4285.799805, -4263.199707, -4240.600098, -2569.600098, -4195.399902, -4172.800293, -4150.199707, -2514.399902, -4105., -4082.400146, -4059.800293, -1552., -2175.200195, -2162.599854, -2150., -1228.800049, -630., -624.799561, -619.599854, -1198.400024, -609.199463, -603.999756, -598.800049, -1167.999878, -588.400391, -583.199951, -578., -1411.200073, -1973.599854, -1961.000122, -1948.400146, -1107.199829, -546.800171, -541.599976, -536.400269, -1076.800049, -525.999756, -520.800049, -515.599976, -1046.400146, -505.199829, -500., -494.799683, -1270.399902, -1772., -1759.400146, -1746.799927, -985.599976, -463.600098, -458.399902, -453.199951, -955.199951, -442.799927, -437.599976, -432.400269, -924.799988, -422.000122, -416.800171, -411.599976, -1763.199951, -3037.25, -3021.449951, -3005.649902, -2529.800293, -4096.299805, -4074.699951, -4053.100098, -2476.600098, -4009.900146, -3988.300049, -3966.699951, -2423.399902, -3923.5, -3901.899902, -3880.299805, -1459.999878, -1991.699951, -1980.099854, -1968.500122, -1046.800049, -266.999878, -263.799805, -260.599854, -1020.400146, -254.199829, -251., -247.799927, -994., -241.400269, -238.200073, -234.999878, -1327.200073, -1806.099854, -1794.500122, -1782.900146, -941.199951, -215.799927, -212.600098, -209.399902, -914.799988, -203.000122, -199.799683, -196.599976, -888.400024, -190.200317, -186.999878, -183.799805, -1194.399902, -1620.500122, -1608.899902, -1597.299927, -835.599915, -164.599976, -161.400269, -158.200195, -809.200073, -151.799927, -148.599976, -145.400024, -782.799927, -139., -135.799805, -132.599976, -1715.200073, -2941.5, -2926.199951, -2910.899902, -2434.800049, -3906.799805, -3886.199951, -3865.599609, -2383.600098, -3824.400391, -3803.800049, -3783.199951, -2332.400146, -3742., -3721.400146, -3700.799805, -1367.999878, -1808.199707, -1797.599854, -1786.999878, -864.800049, 95.999878, 
97.200073, 98.400024, -842.39978, 100.799927, 102.000244, 103.200439, -820., 105.599609, 106.800171, 108., -1243.199951, -1638.599854, -1628.000122, -1617.400146, -775.199829, 115.200195, 116.400146, 117.60022, -752.799805, 120., 121.200073, 122.400024, -730.399841, 124.799927, 125.999878, 127.199951, -1118.400024, -1468.999878, -1458.400146, -1447.799927, -685.599915, 134.400146, 135.60022, 136.800171, -663.199951, 139.200073, 140.399902, 141.599731, -640.799988, 144., 145.200195, 146.400146, -1667.199951, -2845.749756, -2830.949707, -2816.149902, -2339.799805, -3717.300049, -3697.699951, -3678.100098, -2290.600098, -3638.900146, -3619.300049, -3599.699951, -2241.399902, -3560.5, -3540.899902, -3521.299805, -1276., -1624.699951, -1615.100098, -1605.499878, -682.799927, 459.000122, 458.199951, 457.400146, -664.400024, 455.800049, 454.999878, 454.200439, -646.000122, 452.599976, 451.799805, 451.000122, -1159.200073, -1471.099854, -1461.5, -1451.900146, -609.199829, 446.200195, 445.400024, 444.600098, -590.799927, 443., 442.200073, 441.399658, -572.39978, 439.799927, 439.000122, 438.200073, -1042.399902, -1317.499756, -1307.900146, -1298.299683, -535.599976, 433.399963, 432.600098, 431.799744, -517.200012, 430.200195, 429.400024, 428.599976, -498.799927, 427.000061, 426.200256, 425.400024, -1619.199951, -2750., -2735.699951, -2721.399902, -2244.799805, -3527.799805, -3509.199951, -3490.600098, -2197.600098, -3453.400146, -3434.800049, -3416.199951, -2150.399902, -3379., -3360.400146, -3341.800049, -1184., -1441.199951, -1432.599854, -1424., -500.799927, 822.000122, 819.200195, 816.400146, -486.400024, 810.799927, 808.000244, 805.200073, -472., 799.60022, 796.799683, 794.000122, -1075.199951, -1303.599854, -1295.000122, -1286.400024, -443.199951, 777.200073, 774.400024, 771.599854, -428.799927, 766., 763.200317, 760.400024, -414.400146, 754.800049, 752.000244, 749.200195, -966.400146, -1166.000122, -1157.400146, -1148.799927, -385.600098, 732.400024, 729.599976, 
726.799927, -371.200134, 721.200012, 718.400146, 715.599792, -356.799988, 710.000183, 707.199951, 704.400024, -1571.199951, -2654.25, -2640.449951, -2626.649902, -2149.800049, -3338.299805, -3320.699951, -3303.100098, -2104.600098, -3267.900146, -3250.299805, -3232.699951, -2059.399902, -3197.5, -3179.900146, -3162.300049, -1092., -1257.699951, -1250.099854, -1242.499878, -318.799927, 1185.000122, 1180.200439, 1175.400146, -308.399902, 1165.800293, 1161.000122, 1156.200073, -298., 1146.599731, 1141.800049, 1137.000122, -991.199951, -1136.099976, -1128.500122, -1120.899902, -277.199951, 1108.199829, 1103.400146, 1098.599976, -266.799927, 1089.000366, 1084.199951, 1079.400024, -256.399902, 1069.799927, 1065.000122, 1060.200317, -890.400024, -1014.5, -1006.900024, -999.299988, -235.599976, 1031.399902, 1026.599854, 1021.800049, -225.199951, 1012.200195, 1007.400024, 1002.599854, -214.799805, 992.999878, 988.199707, 983.400146, -1523.199951, -2558.5, -2545.199951, -2531.899902, -2054.800049, -3148.800049, -3132.199951, -3115.599854, -2011.599976, -3082.400146, -3065.800049, -3049.199951, -1968.400024, -3016., -2999.400146, -2982.799805, -1000.000061, -1074.199951, -1067.599976, -1061.000244, -136.799805, 1548.000244, 1541.200195, 1534.400269, -130.400146, 1520.800171, 1514.000122, 1507.200073, -124., 1493.600098, 1486.799805, 1480.000244, -907.200073, -968.599976, -962.000122, -955.400085, -111.199951, 1439.200073, 1432.399902, 1425.599854, -104.800049, 1412.000122, 1405.200195, 1398.400024, -98.400024, 1384.799927, 1378.000366, 1371.200195, -814.400024, -862.999939, -856.399902, -849.799927, -85.599976, 1330.400024, 1323.599854, 1316.799927, -79.200073, 1303.200073, 1296.399902, 1289.599731, -72.799927, 1276., 1269.200195, 1262.400024, -1475.200073, -2462.75, -2449.949951, -2437.149902, -1959.800049, -2959.299805, -2943.699951, -2928.099854, -1918.599976, -2896.900146, -2881.300049, -2865.699951, -1877.399902, -2834.5, -2818.900146, -2803.300049, -907.999939, 
-890.700012, -885.099915, -879.499878, 45.199829, 1911., 1902.200073, 1893.400024, 47.599976, 1875.800293, 1867.000244, 1858.200073, 49.999878, 1840.599976, 1831.800171, 1823.000244, -823.200073, -801.100098, -795.500061, -789.900024, 54.799927, 1770.199951, 1761.400269, 1752.599976, 57.200073, 1735., 1726.200073, 1717.400269, 59.599976, 1699.799805, 1691., 1682.200073, -738.400024, -711.499817, -705.900085, -700.299927, 64.400146, 1629.399902, 1620.599976, 1611.800171, 66.800049, 1594.200195, 1585.39978, 1576.599976, 69.200073, 1559.000122, 1550.199829, 1541.400146, 1260.800049, 2211.5, 2228.800049, 2246.100098, 1921.200073, 3207.200195, 3231.800049, 3256.399902, 1980.400024, 3305.599854, 3330.200195, 3354.800049, 2039.599854, 3404., 3428.599854, 3453.200195, 1400., 2129.800049, 2144.400146, 2159., 1479.199951, 1588.000244, 1597.200073, 1606.400024, 1517.599976, 1624.800171, 1634., 1643.199951, 1556., 1661.600098, 1670.800171, 1679.999878, 1556.799927, 2363.400146, 2378., 2392.600098, 1632.799805, 1735.199951, 1744.400146, 1753.600098, 1671.199829, 1771.999878, 1781.200073, 1790.400024, 1709.60022, 1808.800171, 1818.000244, 1827.200073, 1713.599976, 2597., 2611.599854, 2626.199951, 1786.400024, 1882.400024, 1891.600098, 1900.800171, 1824.799805, 1919.200195, 1928.400146, 1937.600098, 1863.199951, 1956., 1965.199951, 1974.400391, 1228.800049, 2147.25, 2164.049805, 2180.850098, 1856.199951, 3076.700195, 3100.300049, 3123.899902, 1913.400024, 3171.099854, 3194.700195, 3218.300049, 1970.599976, 3265.5, 3289.099854, 3312.699951, 1332., 1993.300049, 2006.900146, 2020.499878, 1341.199951, 1310.999878, 1318.199951, 1325.400146, 1375.60022, 1339.800171, 1347., 1354.199951, 1410., 1368.600098, 1375.800171, 1383., 1480.800049, 2210.900146, 2224.5, 2238.100098, 1478.799805, 1426.200073, 1433.400146, 1440.599609, 1513.199951, 1455., 1462.199951, 1469.400024, 1547.60022, 1483.799927, 1490.999878, 1498.199951, 1629.599976, 2428.500244, 2442.100098, 2455.699951, 1616.399902, 
1541.400146, 1548.600098, 1555.799683, 1650.800049, 1570.200073, 1577.400024, 1584.600098, 1685.199951, 1598.99939, 1606.200317, 1613.400024, 1196.800049, 2083., 2099.300049, 2115.600098, 1791.200073, 2946.200195, 2968.800049, 2991.400146, 1846.400024, 3036.599854, 3059.200195, 3081.800049, 1901.599976, 3127., 3149.599854, 3172.200195, 1264., 1856.800049, 1869.400146, 1881.999878, 1203.200073, 1034., 1039.200073, 1044.400146, 1233.599976, 1054.799927, 1059.999878, 1065.199951, 1263.999878, 1075.599609, 1080.800171, 1086., 1404.799927, 2058.400146, 2071., 2083.599854, 1324.799927, 1117.199951, 1122.400146, 1127.599609, 1355.199951, 1138., 1143.200439, 1148.400146, 1385.599976, 1158.800171, 1164.000244, 1169.200073, 1545.599976, 2260., 2272.600098, 2285.199951, 1446.400024, 1200.400146, 1205.600098, 1210.800171, 1476.799805, 1221.199951, 1226.400024, 1231.600098, 1507.199951, 1242.000244, 1247.200073, 1252.400146, 1164.800049, 2018.75, 2034.549927, 2050.350098, 1726.200073, 2815.700195, 2837.300049, 2858.900146, 1779.400024, 2902.099854, 2923.700195, 2945.300049, 1832.599976, 2988.5, 3010.099854, 3031.700195, 1196.000122, 1720.300049, 1731.900146, 1743.499878, 1065.200073, 757.000122, 760.200073, 763.400024, 1091.599976, 769.800171, 773., 776.199951, 1118., 782.599976, 785.800049, 789., 1328.800049, 1905.900146, 1917.499878, 1929.100098, 1170.799805, 808.200073, 811.400024, 814.60022, 1197.199951, 821., 824.199951, 827.400024, 1223.599976, 833.799927, 837.000244, 840.199951, 1461.599976, 2091.5, 2103.100098, 2114.700195, 1276.400146, 859.400024, 862.600098, 865.800293, 1302.799927, 872.200073, 875.400146, 878.599854, 1329.199951, 885., 888.199951, 891.400024, 1132.800049, 1954.500122, 1969.799927, 1985.099976, 1661.199951, 2685.200195, 2705.800049, 2726.399902, 1712.399902, 2767.599854, 2788.200195, 2808.800049, 1763.599976, 2850., 2870.599854, 2891.199951, 1128., 1583.800049, 1594.400146, 1605., 927.200012, 480., 481.199951, 482.400146, 949.599976, 484.800171, 486., 
487.200073, 971.999878, 489.599731, 490.800171, 492.000122, 1252.799927, 1753.400146, 1763.999878, 1774.600098, 1016.799805, 499.200195, 500.400024, 501.60022, 1039.199951, 504., 505.199951, 506.400146, 1061.599976, 508.799927, 510., 511.200195, 1377.599976, 1923.000122, 1933.600098, 1944.200073, 1106.400024, 518.400024, 519.60022, 520.800171, 1128.799927, 523.199829, 524.400024, 525.600098, 1151.199829, 528., 529.199829, 530.400146, 1100.800049, 1890.25, 1905.050049, 1919.849976, 1596.199951, 2554.700195, 2574.300049, 2593.900146, 1645.399902, 2633.099854, 2652.700195, 2672.300049, 1694.599976, 2711.5, 2731.099854, 2750.700195, 1060., 1447.299805, 1456.900146, 1466.499878, 789.200012, 203.000122, 202.200195, 201.400146, 807.600098, 199.800171, 199., 198.200195, 826., 196.599731, 195.800049, 195., 1176.799927, 1600.900146, 1610.500244, 1620.099854, 862.80011, 190.200317, 189.400146, 188.60022, 881.199951, 187., 186.199829, 185.400024, 899.60022, 183.800171, 183., 182.200073, 1293.599976, 1754.499878, 1764.099854, 1773.700073, 936.400024, 177.400146, 176.60022, 175.800049, 954.799805, 174.199951, 173.400024, 172.599854, 973.200073, 171., 170.200073, 169.400146, 1068.800049, 1826., 1840.299927, 1854.599976, 1531.199951, 2424.200195, 2442.800049, 2461.399902, 1578.399902, 2498.599854, 2517.199951, 2535.800049, 1625.599976, 2573., 2591.599854, 2610.200195, 991.999939, 1310.800049, 1319.400146, 1328., 651.199951, -74., -76.799805, -79.599854, 665.600098, -85.199829, -87.999756, -90.799805, 680., -96.400024, -99.199829, -102., 1100.800049, 1448.400146, 1456.999878, 1465.600098, 708.800049, -118.799805, -121.599976, -124.400269, 723.199829, -130., -132.800171, -135.599976, 737.599976, -141.200073, -144., -146.799805, 1209.599976, 1586., 1594.600098, 1603.200073, 766.400146, -163.599976, -166.39978, -169.200073, 780.800049, -174.799927, -177.599976, -180.400146, 795.199951, -185.999878, -188.800171, -191.599854, 1036.800049, 1761.75, 1775.550049, 1789.349976, 1466.200073, 
2293.700195, 2311.300049, 2328.900146, 1511.399902, 2364.099854, 2381.700195, 2399.300049, 1556.599976, 2434.5, 2452.099854, 2469.700195, 923.999939, 1174.300049, 1181.899902, 1189.5, 513.200073, -350.999756, -355.799805, -360.599854, 523.599976, -370.199951, -374.999939, -379.799805, 534., -389.400146, -394.19989, -398.999817, 1024.800049, 1295.900146, 1303.5, 1311.10022, 554.799927, -427.800171, -432.599854, -437.400146, 565.199951, -446.999878, -451.799805, -456.599854, 575.599976, -466.200317, -470.999756, -475.799805, 1125.599976, 1417.499878, 1425.100098, 1432.700073, 596.400024, -504.599854, -509.400269, -514.199951, 606.800049, -523.800171, -528.599609, -533.400146, 617.200073, -542.999878, -547.800171, -552.599854, 1004.800049, 1697.5, 1710.799927, 1724.099976, 1401.199951, 2163.200195, 2179.800049, 2196.400146, 1444.400024, 2229.599854, 2246.200195, 2262.800049, 1487.599976, 2296., 2312.599854, 2329.200195, 855.999939, 1037.800049, 1044.400146, 1051., 375.199951, -627.999756, -634.800171, -641.599976, 381.599976, -655.199829, -661.999878, -668.80011, 388.000061, -682.400146, -689.199951, -695.999756, 948.799988, 1143.400146, 1149.999878, 1156.60022, 400.799805, -736.799927, -743.599976, -750.399902, 407.200073, -763.999878, -770.799805, -777.599731, 413.599976, -791.200073, -797.999756, -804.800171, 1041.599976, 1248.999878, 1255.60022, 1262.200073, 426.399902, -845.599854, -852.400146, -859.200073, 432.799927, -872.799805, -879.599854, -886.400024, 439.200073, -899.999878, -906.799927, -913.599976, 972.800049, 1633.25, 1646.049927, 1658.850098, 1336.200073, 2032.700195, 2048.300049, 2063.900146, 1377.400024, 2095.099854, 2110.700195, 2126.300049, 1418.599976, 2157.5, 2173.099854, 2188.700195, 787.999939, 901.299988, 906.899963, 912.500061, 237.200012, -904.999817, -913.799866, -922.599792, 239.599976, -940.199707, -948.999817, -957.800171, 242., -975.400146, -984.199829, -992.999756, 872.799988, 990.899963, 996.499878, 1002.10022, 246.800049, 
-1045.799927, -1054.599854, -1063.400024, 249.200073, -1080.999878, -1089.799805, -1098.599854, 251.600098, -1116.199951, -1124.999878, -1133.799683, 957.599976, 1080.499878, 1086.10022, 1091.700073, 256.400024, -1186.599854, -1195.400146, -1204.199829, 258.799927, -1221.800171, -1230.599976, -1239.400269, 261.199951, -1257., -1265.799927, -1274.600098}, sd::DataType::FLOAT32); + + input.linspace(-32, 0.1); + + sd::ops::deconv3d op; + auto results = op.evaluate({&input, &weights}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_bp_test1) { + + int bS=1, iD=3,iH=3,iW=3, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto bias = NDArrayFactory::create('c', {iC}); + auto gradO = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + + NDArray expGradI('c', {bS, oD, oH, oW, oC}, {62., 67.6, 68.4, 74.8, 81.2, 89.2, 87.6, 96.4, 119.6, 132.4, 126., 139.6, 138.8, 154., 145.2, 161.2}, sd::DataType::FLOAT32); + NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28., 32., 32., 40., 40., 44., 44., 64, 64., 68., 68., 76., 76., 80., 80.}, sd::DataType::FLOAT32); + NDArray expGradB('c', {iC}, std::vector{364.5}, sd::DataType::FLOAT32); + + input = 0.5; + weights.linspace(0.1, 0.1); + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); + + auto gradI = results.at(0); + auto 
gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + + } + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_bp_test2) { + + int bS=1, iD=2,iH=2,iW=2, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oD, oH, oW, oC}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}); + auto gradO = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + + NDArray expGradI('c', {bS, oD, oH, oW, oC}, {34, 37.2, 16.6, 18.4, 15.4, 17.4, 7.1, 8.2, 10.6, 13., 4.3, 5.6, 2.9, 4.3, 0.75, 1.5}, sd::DataType::FLOAT32); + NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16, 16, 9, 9, 10, 10, 5.5, 5.5, 12, 12, 6.5, 6.5, 7, 7, 3.75, 3.75}, sd::DataType::FLOAT32); + + input = 0.5; + weights.linspace(0.1, 0.1); + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_bp_test3) { + + int bS=1, iD=3,iH=3,iW=3, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 
1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f}); + auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + + NDArray expGradI('c', {bS, oD, oH, oW, oC}, {33.8, 37.4, 44.6, 48.2, 66.2, 69.8, 77., 80.6, 77.25, 86.35, 104.55, 113.65, 159.15, 168.25, 186.45, 195.55}, sd::DataType::FLOAT32); + NDArray expGradW('c', {kD, kH, kW, iC, oC}, {28., 28, 32, 32, 40, 40, 44, 44, 64, 64, 68, 68, 76, 76, 80, 80.}, sd::DataType::FLOAT32); + + input = 0.5; + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_bp_test4) { + + int bS=1, iD=2,iH=2,iW=2, iC=1,oC=2, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; + int oD=3,oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, oC, oD, oH, oW}); + auto weights = NDArrayFactory::create('c', {kD, kH, kW, iC, oC}, {0.1f, 0.9f, 0.2f, 0.1f, 0.3f, 1.1f, 0.4f, 1.2f, 0.5f, 1.3f, 0.6f, 1.4f, 0.7f, 1.5f, 0.8f, 1.6f}); + auto gradO = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + + NDArray expGradI('c', {bS, oC, oD, oH, oW}, {0.4, 1.55, 1.05, 2.3, 5.7, 3.2, 1.5, 3.35, 1.75, 3.8, 8.3, 4.3, 9.0, 18.6, 9.2, 4.4, 8.7, 4.1, 1.8, 3.55, 1.65, 3.5, 6.5, 2.8, 1.3, 2.15, 0.75, 
0.8, 3.15, 2.25, 4.7, 12.1, 7.2, 3.5, 8.15, 4.55, 7.8, 17.9, 9.9, 19.75, 42.85, 23.6, 9.35, 21.55, 12.9, 5.4, 11.55, 6.05, 8.25, 20.75, 13.2, 0.65, 6.6, 6.75}, sd::DataType::FLOAT32); + NDArray expGradW('c', {kD, kH, kW, iC, oC}, {16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.0, 16.}, sd::DataType::FLOAT32); + + input = 0.5; + gradO.linspace(0.5); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat}, {}); + + auto gradI = results.at(0); + auto gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_bp_test5) { + + int bS=2, iD=4,iH=4,iW=4, iC=3,oC=2, kD=2,kH=1,kW=1, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=4,oH=4,oW=4; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iC, iD, iH, iW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c',{iC, oC, kD, kH, kW}, {-0.6, 0., -0.3, 0.3, -0.5, 0.1, -0.2, 0.4, -0.4, 0.2, -0.1, 0.5}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oD, oH, oW},sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iD, iH, iW}, {9.696001, 9.684001, 9.672001, 9.66, 9.648001, 9.636, 9.624001, 9.612, 9.600001, 9.587999, 9.576, 9.564001, 9.552, + 9.540001, 9.528, 9.516, 9.504001, 9.492, 9.480001, 9.468, 9.455999, 9.444, 9.432001, 9.420001, 9.408001, 9.396, 9.384001, 9.372001, 9.36, 9.348001, 9.335999, + 9.324001, 9.312, 9.300001, 9.288001, 9.276001, 9.264, 9.252001, 9.24, 
9.228001, 9.216, 9.204, 9.191999, 9.18, 9.168001, 9.156, 9.144001, 9.132, 13.152, 13.134001, + 13.116, 13.098, 13.080001, 13.062, 13.044001, 13.026001, 13.008001, 12.990001, 12.972, 12.954, 12.936001, 12.918, 12.900002, 12.882, 3.616001, 3.612, 3.608, 3.604, + 3.6, 3.596, 3.592, 3.588, 3.584001, 3.579999, 3.576001, 3.571999, 3.568, 3.564, 3.56, 3.556, 3.552, 3.548, 3.544, 3.539999, 3.536001, 3.532001, 3.527999, 3.524001, 3.52, 3.516, 3.512, 3.508, 3.504, 3.5, 3.496, 3.492, 3.487999, 3.484001, 3.48, 3.476, 3.472, 3.468, 3.464, 3.46, 3.456, 3.452, 3.447999, 3.444001, 3.439999, 3.436, 3.432001, 3.428, 10.272, 10.258, 10.244, 10.23, 10.216, 10.202, 10.188, 10.174, 10.16, 10.146, 10.132, 10.118, 10.104, 10.09, 10.076, 10.062, -2.464, -2.460001, -2.455999, -2.452, -2.448, -2.444, -2.44, -2.436, -2.432, -2.428, -2.424, -2.42, -2.415999, -2.412, -2.408, -2.404, -2.4, -2.396, -2.392, -2.388, -2.384, -2.38, -2.376, -2.372, -2.368, -2.363999, -2.36, -2.356, -2.352, -2.348, -2.344, -2.34, -2.336, -2.332, -2.328001, -2.323999, -2.32, -2.316, -2.312, -2.308, -2.304, -2.3, -2.296, -2.292, -2.288, -2.283999, -2.28, -2.276, 7.392, 7.382, 7.372, 7.362, 7.352, 7.342, 7.332, 7.322, 7.312, 7.302, 7.292, 7.282, 7.272, 7.262, 7.252, 7.242, 8.16, 8.148001, 8.136001, 8.124001, 8.112, 8.1, 8.087999, 8.076, 8.063999, 8.052, 8.04, 8.028001, 8.016, 8.004001, 7.992001, 7.98, 7.968, 7.956, 7.944, 7.932001, 7.92, 7.908, 7.896, 7.884, 7.872001, 7.86, 7.848001, 7.835999, 7.824, 7.812, 7.800001, 7.788, 7.776, 7.764, 7.752, 7.740001, 7.728, 7.716001, 7.704, 7.692, 7.68, 7.668, 7.656, 7.644001, 7.632001, 7.62, 7.608001, 7.596001, 10.848, 10.830001, 10.812, 10.794001, 10.776, 10.758, 10.74, 10.722, 10.704, 10.686001, 10.668, 10.650001, 10.632, 10.614, 10.596001, 10.578001, 3.104, 3.1, 3.096, 3.092, 3.088, 3.084, 3.079999, 3.076001, 3.072, 3.068, 3.064, 3.06, 3.056, 3.052, 3.048, 3.044, 3.039999, 3.036001, 3.032, 3.028, 3.024001, 3.02, 3.016, 3.012, 3.008, 3.004, 3., 2.996, 2.992, 2.987999, 2.984001, 
2.98, 2.976, 2.972, 2.968, 2.964, 2.96, 2.956, 2.952, 2.947999, 2.944001, 2.94, 2.936, 2.932001, 2.928, 2.924, 2.92, 2.916, 8.48, 8.466, 8.452, 8.438, 8.424, 8.41, 8.396, 8.382, 8.368, 8.354, 8.34, 8.326, 8.312, 8.298, 8.284, 8.27, -1.952, -1.948, -1.944, -1.94, -1.936, -1.932, -1.928, -1.924, -1.92, -1.916, -1.912, -1.908, -1.904, -1.9, -1.896, -1.892, -1.888, -1.884, -1.88, -1.876, -1.872, -1.868, -1.863999, -1.86, -1.856, -1.852, -1.848, -1.844, -1.84, -1.836, -1.832, -1.828, -1.823999, -1.82, -1.816, -1.812, -1.808, -1.804, -1.8, -1.796, -1.792, -1.788, -1.784, -1.78, -1.776, -1.771999, -1.768, -1.764, 6.112, 6.102, 6.092, 6.082, 6.072, 6.062, 6.052, 6.042, 6.032, 6.022, 6.012, 6.002, 5.992, 5.982, 5.972, 5.962}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, oC, kD, kH, kW}, {-73678.695312, -59907.972656, -67739.515625, -54962.082031, -15966.075195, -17115.042969, -15269.777344, -16101.275391, 41746.566406, 25677.917969, 37200.003906, 22759.517578}, sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-1803.520020, -1639.679932}, sd::DataType::FLOAT32); + + input.linspace(100., -0.5); + gradO.linspace(-16, 0.02); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, deconv3d_bp_test6) { + + int bS=2, iD=4,iH=4,iW=4, iC=3,oC=2, kD=2,kH=1,kW=1, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=5,oH=4,oW=4; + int paddingMode = 0; // 
1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0 - [kD, kH, kW, oC, iC], 1 - [iC, oC, kD, kH, kW], 2 - [iC, kD, kH, kW, oC] + + NDArray input('c', {bS, iD, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {-0.1, 0.2}, sd::DataType::FLOAT32); + NDArray weights('c',{iC, kD, kH, kW, oC}, {-0.6, -0.3, 0., 0.3, -0.5, -0.2, 0.1, 0.4, -0.4, -0.1, 0.2, 0.5}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oD, oH, oW, oC}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iD, iH, iW, iC}, {1.056, 0.482, -0.092, 1.044, 0.478, -0.088, 1.032, 0.474, -0.084, 1.02, 0.47, -0.08, 1.008, 0.466, -0.076, 0.996, + 0.462, -0.072, 0.984, 0.458, -0.068, 0.972, 0.454, -0.064, 0.96, 0.45, -0.06, 0.948, 0.446, -0.056, 0.936, 0.442, -0.052, 0.924, 0.438, -0.048, 0.912, 0.434, + -0.044, 0.9, 0.43, -0.04, 0.888, 0.426, -0.036, 0.876, 0.422, -0.032, 0.864, 0.418, -0.028, 0.852, 0.414, -0.024, 0.84, 0.41, -0.02, 0.828, 0.406, -0.016, + 0.816, 0.402, -0.012, 0.804, 0.398, -0.008, 0.792, 0.394, -0.004, 0.78, 0.39, 0., 0.768, 0.386, 0.004, 0.756, 0.382, 0.008, 0.744, 0.378, 0.012, 0.732, 0.374, + 0.016, 0.72, 0.37, 0.02, 0.708, 0.366, 0.024, 0.696, 0.362, 0.028, 0.684, 0.358, 0.032, 0.672, 0.354, 0.036, 0.66, 0.35, 0.04, 0.648, 0.346, 0.044, 0.636, 0.342, 0.048, 0.624, 0.338, 0.052, 0.612, 0.334, 0.056, 0.6, 0.33, 0.06, 0.588, 0.326, 0.064, 0.576, 0.322, 0.068, 0.564, 0.318, 0.072, 0.552, 0.314, 0.076, 0.54, 0.31, 0.08, 0.528, 0.306, 0.084, 0.516, 0.302, 0.088, 0.504, 0.298, 0.092, 0.492, 0.294, 0.096, 0.48, 0.29, 0.1, 0.468, 0.286, 0.104, 0.456, 0.282, 0.108, 0.444, 0.278, 0.112, 0.432, 0.274, 0.116, 0.42, 0.27, 0.12, 0.408, 0.266, 0.124, 0.396, 0.262, 0.128, 0.384, 0.258, 0.132, 0.372, 0.254, 0.136, 0.36, 0.25, 0.14, 0.348, 0.246, 0.144, 0.336, 0.242, 0.148, 0.324, 0.238, 0.152, 0.312, 0.234, 0.156, 0.3, 0.23, 0.16, 0.096, 0.162, 0.228, 0.084, 0.158, 0.232, 0.072, 0.154, 0.236, 0.06, 0.15, 0.24, 0.048, 0.146, 0.244, 0.036, 0.142, 0.248, 0.024, 
0.138, 0.252, 0.012, 0.134, 0.256, 0., 0.13, 0.26, -0.012, 0.126, 0.264, -0.024, 0.122, 0.268, -0.036, 0.118, 0.272, -0.048, 0.114, 0.276, -0.06, 0.11, 0.28, -0.072, 0.106, 0.284, -0.084, 0.102, 0.288, -0.096, 0.098, 0.292, -0.108, 0.094, 0.296, -0.12, 0.09, 0.3, -0.132, 0.086, 0.304, -0.144, 0.082, 0.308, -0.156, 0.078, 0.312, -0.168, 0.074, 0.316, -0.18, 0.07, 0.32, -0.192, 0.066, 0.324, -0.204, 0.062, 0.328, -0.216, 0.058, 0.332, -0.228, 0.054, 0.336, -0.24, 0.05, 0.34, -0.252, 0.046, 0.344, -0.264, 0.042, 0.348, -0.276, 0.038, 0.352, -0.288, 0.034, 0.356, -0.3, 0.03, 0.36, -0.312, 0.026, 0.364, -0.324, 0.022, 0.368, -0.336, 0.018, 0.372, -0.348, 0.014, 0.376, -0.36, 0.01, 0.38, -0.372, 0.006, 0.384, -0.384, 0.002, 0.388, -0.396, -0.002, 0.392, -0.408, -0.006, 0.396, -0.42, -0.01, 0.4, -0.432, -0.014, 0.404, -0.444, -0.018, 0.408, -0.456, -0.022, 0.412, -0.468, -0.026, 0.416, -0.48, -0.03, 0.42, -0.492, -0.034, 0.424, -0.504, -0.038, 0.428, -0.516, -0.042, 0.432, -0.528, -0.046, 0.436, -0.54, -0.05, 0.44, -0.552, -0.054, 0.444, -0.564, -0.058, 0.448, -0.576, -0.062, 0.452, -0.588, -0.066, 0.456, -0.6, -0.07, 0.46, -0.612, -0.074, 0.464, -0.624, -0.078, 0.468, -0.636, -0.082, 0.472, -0.648, -0.086, 0.476, -0.66, -0.09, 0.48}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {iC, kD, kH, kW, oC}, {-6328.958984, -6322.880371, -6134.400879, -6128.319824, -6318.079590, -6312.640137, -6144.000000, -6138.560547, -6307.202637, -6302.399414, -6153.599609, -6148.799316}, sd::DataType::FLOAT32); + NDArray expGradB('c', {oC}, {-1.599994, 0.000001}, sd::DataType::FLOAT32); + + input.linspace(100., -0.5); + gradO.linspace(-1.6, 0.01); + + sd::ops::deconv3d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, dataFormat, wFormat}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto gradI = results.at(0); + auto gradW = results.at(1); + auto gradB = results.at(2); + + 
ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, maxpool2d_1) { + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, maxpool2d_2) { + + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height + const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width + + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // 
variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, maxpool2d_3) { + + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (int) sd::math::nd4j_ceil(iH * 1.f / sH); + const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); + + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, maxpool2d_4) { + + const int bS = 2; + const int iD = 1; + const int iH = 24; + const int iW = 24; + const int kH = 3; + const int kW = 3; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height + const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width + + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, maxpool2d_5) { + + const int bS = 2; + const int iD = 1; + const int iH = 24; + const int iW = 24; + const int kH = 3; + const int kW = 3; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (int) sd::math::nd4j_ceil(iH * 1.f / sH); + const int oW = (int) sd::math::nd4j_ceil(iW * 1.f / sW); + + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto 
z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dH,dW, 1}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result.printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) { + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f}); + + x.linspace(1); + + sd::ops::maxpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) { + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {11.f, 12.f, 15.f, 16.f, 27.f, 28.f, 31.f, 32.f, 43.f, 44.f, 47.f, 48.f, 59.f, 60.f, 63.f, 64.f}); + + x.linspace(1); + + sd::ops::maxpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +#if 0 + 
exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) { + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {7.f, 9.f, 17.f, 19.f, 32.f, 34.f, 42.f, 44.f, 57.f, 59.f, 67.f, 69.f, 82.f, 84.f, 92.f, 94.f}); + + x.linspace(1); + + sd::ops::maxpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool2d_9) { + + int bS = 3; // batch size (number of samples) + int iC = 3; // input channels + int iH = 28, iW = 28; // input height/width + int kH = 2, kW = 2; // kernel (filter) height/width + int sH = 1, sW = 1; // stride height/width + int pH = 0, pW = 0; // padding height/width + int dH = 1, dW = 1; // dilation height/width + + int oH = 27, oW = 27; // output height/width + + int isSameMode = 0; // 1-SAME, 0-VALID + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + + sd::ops::maxpool2d op; + auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, isSameMode, 1, 0}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output->isSameShape({bS, iC, oH, oW})); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) { + + int bS=1, iH=4,iW=4, iC=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = 
NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.27620894f, 0.21801452f, 0.062078513f, 7.348895E-4f, 0.24149609f, 0.4948205f, 0.93483436f, 0.52035654f, 0.30292067f, 0.3289706f, 0.7977864f, + 0.03180518f, 0.1455722f, 0.90352905f, 0.9405744f, 0.0048329555f, 0.44062102f, 0.111197524f, 0.31742015f, 0.1933705f, 0.23825112f, 0.35076278f, 0.7135856f, 0.28229436f, 0.18310733f, + 0.9613717f, 0.56823575f, 0.78289545f, 0.62195826f, 0.5244586f, 0.5040889f, 0.025349546f, 0.41400263f, 0.28420195f, 0.8536445f, 0.3044107f, 0.7997134f, 0.45762005f, 0.7653578f, + 0.07198584f, 0.5304998f, 0.7334402f, 0.85019743f, 0.031957153f, 0.37088063f, 0.85722464f, 0.06376881f, 0.39791203f}); + + auto expOutput = NDArrayFactory::create('c', {bS, iC, oH, oW}, {0.4948205f, 0.93483436f, 0.93483436f, 0.4948205f, 0.93483436f, 0.93483436f, 0.90352905f, 0.9405744f, 0.9405744f, 0.44062102f, 0.7135856f, + 0.7135856f, 0.9613717f, 0.9613717f, 0.78289545f, 0.9613717f, 0.9613717f, 0.78289545f, 0.7997134f, 0.8536445f, 0.8536445f, 0.7997134f, 0.85019743f, 0.85019743f, + 0.85722464f, 0.85722464f, 0.85019743f}); + + sd::ops::maxpool2d op; + auto results = op.evaluate({&input}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); +#if 0 + expOutput.printIndexedBuffer("expOutput"); + output->printIndexedBuffer("output"); +#endif + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, maxpool2d_11) { + + NDArray input('c', {1,1,4,5}, sd::DataType::FLOAT32); + NDArray z('c', {1,1,4,5}, sd::DataType::FLOAT32); + + input.linspace(1.); + + sd::ops::maxpool2d op; + auto results = op.evaluate({&input}, {}, {2,2, 1,1, 1,1, 2,2, 1,0,0}); + + ASSERT_EQ(Status::OK(), results.status()); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool3d_test1) { + 
+ int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {10.5f, 11.5f, 13.5f, 14.5f, 22.5f, 23.5f, 25.5f, 26.5f, 46.5f, 47.5f, 49.5f, 50.5f, 58.5f, 59.5f, 61.5f, 62.5f, + 82.5f, 83.5f, 85.5f, 86.5f, 94.5f, 95.5f, 97.5f, 98.5f,118.5f,119.5f,121.5f,122.5f,130.5f,131.5f,133.5f,134.5f, + 154.5f,155.5f,157.5f,158.5f,166.5f,167.5f,169.5f,170.5f,190.5f,191.5f,193.5f,194.5f,202.5f,203.5f,205.5f,206.5f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool3d_test2) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=3,oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 29.5f, 30.5f, 31.5f, 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 34.f, 35.f, 36.f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 47.5f, 48.5f, 49.5f, + 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 65.5f, 66.5f, 67.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, 70.f, 71.f, 72.f, 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 79.5f, 79.f, 80.f, 81.f, 79.f, 80.f, 81.f, 82.f, 83.f, 84.f, 83.5f, 84.5f, 85.5f, + 79.f, 80.f, 81.f, 82.f, 83.f, 
84.f, 83.5f, 84.5f, 85.5f, 83.5f, 84.5f, 85.5f, 86.5f, 87.5f, 88.5f, 88.f, 89.f, 90.f, 92.5f, 93.5f, 94.5f, 95.5f, 96.5f, 97.5f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 100.f, 101.f, 102.f, 101.5f, 102.5f, 103.5f, + 133.f, 134.f, 135.f, 136.f, 137.f, 138.f, 137.5f, 138.5f, 139.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 142.f, 143.f, 144.f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, 151.f, 152.f, 153.f, 151.f, 152.f, 153.f, 154.f, 155.f, 156.f, 155.5f, 156.5f, 157.5f, + 169.f, 170.f, 171.f, 172.f, 173.f, 174.f, 173.5f, 174.5f, 175.5f, 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 178.f, 179.f, 180.f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f, 187.f, 188.f, 189.f, 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, + 187.f, 188.f, 189.f, 190.f, 191.f, 192.f, 191.5f, 192.5f, 193.5f, 191.5f, 192.5f, 193.5f, 194.5f, 195.5f, 196.5f, 196.f, 197.f, 198.f, 200.5f, 201.5f, 202.5f, 203.5f, 204.5f, 205.5f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 208.f, 209.f, 210.f, 209.5f, 210.5f, 211.5f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool3d_test3) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 29.5f, 30.5f, 31.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 41.5f, 42.5f, 43.5f, 65.5f, 66.5f, 67.5f, 68.5f, 69.5f, 70.5f, + 74.5f, 75.5f, 76.5f, 77.5f, 78.5f, 
79.5f, 137.5f, 138.5f, 139.5f, 140.5f, 141.5f, 142.5f, 146.5f, 147.5f, 148.5f, 149.5f, 150.5f, 151.5f, + 173.5f, 174.5f, 175.5f, 176.5f, 177.5f, 178.5f, 182.5f, 183.5f, 184.5f, 185.5f, 186.5f, 187.5f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool3d_test4) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; + int oD=4,oH=4,oW=4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{0.416667f, 1.00f, 1.333333f, 0.75f, 1.00f, 2.25f, 2.75f, 1.50f, 1.75f, 3.75f, 4.25f, 2.25f, 1.416667f, 3.00f, 3.333333f, 1.75f, 2.833333f, 6.00f, 6.666667f, 3.50f, 5.00f, 10.50f, 11.50f, 6.00f, 6.50f, + 13.50f, 14.50f, 7.50f, 4.833333f, 10.00f, 10.666667f, 5.50f, 6.833333f, 14.00f, 14.666667f, 7.50f, 11.00f, 22.50f, 23.50f, 12.00f, 12.50f, 25.50f, 26.50f, 13.50f, 8.833333f, 18.00f, 18.666666f, 9.50f, + 4.416667f, 9.00f, 9.333333f, 4.75f, 7.00f, 14.25f, 14.75f, 7.50f, 7.75f, 15.75f, 16.25f, 8.25f, 5.416667f, 11.00f, 11.333333f, 5.75f, 6.416667f, 13.00f, 13.333333f, 6.75f, 10.00f, 20.25f, 20.75f, + 10.50f, 10.75f, 21.75f, 22.25f, 11.25f, 7.416667f, 15.00f, 15.333333f, 7.75f, 14.833333f, 30.00f, 30.666666f, 15.50f, 23.00f, 46.50f, 47.50f, 24.00f, 24.50f, 49.50f, 50.50f, 25.50f, 16.833334f, + 34.00f, 34.666668f, 17.50f, 18.833334f, 38.00f, 38.666668f, 19.50f, 29.00f, 58.50f, 59.50f, 30.00f, 30.50f, 61.50f, 62.50f, 31.50f, 20.833334f, 42.00f, 42.666668f, 21.50f, 10.416667f, 
21.00f, + 21.333334f, 10.75f, 16.00f, 32.25f, 32.75f, 16.50f, 16.75f, 33.75f, 34.25f, 17.25f, 11.416667f, 23.00f, 23.333334f, 11.75f, 12.416667f, 25.00f, 25.333334f, 12.75f, 19.00f, 38.25f, 38.75f, 19.50f, + 19.75f, 39.75f, 40.25f, 20.25f, 13.416667f, 27.00f, 27.333334f, 13.75f, 26.833334f, 54.00f, 54.666668f, 27.50f, 41.00f, 82.50f, 83.50f, 42.00f, 42.50f, 85.50f, 86.50f, 43.50f, 28.833334f, 58.00f, + 58.666668f, 29.50f, 30.833334f, 62.00f, 62.666668f, 31.50f, 47.00f, 94.50f, 95.50f, 48.00f, 48.50f, 97.50f, 98.50f, 49.50f, 32.833332f, 66.00f, 66.666664f, 33.50f, 16.416666f, 33.00f, 33.333332f, + 16.75f, 25.00f, 50.25f, 50.75f, 25.50f, 25.75f, 51.75f, 52.25f, 26.25f, 17.416666f, 35.00f, 35.333332f, 17.75f, 18.416666f, 37.00f, 37.333332f, 18.75f, 28.00f, 56.25f, 56.75f, 28.50f, 28.75f, + 57.75f, 58.25f, 29.25f, 19.416666f, 39.00f, 39.333332f, 19.75f, 38.833332f, 78.00f, 78.666664f, 39.50f, 59.00f, 118.50f, 119.50f, 60.00f, 60.50f, 121.50f, 122.50f, 61.50f, 40.833332f, 82.00f, + 82.666664f, 41.50f, 42.833332f, 86.00f, 86.666664f, 43.50f, 65.00f, 130.50f, 131.50f, 66.00f, 66.50f, 133.50f, 134.50f, 67.50f, 44.833332f, 90.00f, 90.666664f, 45.50f, 22.416666f, 45.00f, + 45.333332f, 22.75f, 34.00f, 68.25f, 68.75f, 34.50f, 34.75f, 69.75f, 70.25f, 35.25f, 23.416666f, 47.00f, 47.333332f, 23.75f, 24.416666f, 49.00f, 49.333332f, 24.75f, 37.00f, 74.25f, 74.75f, + 37.50f, 37.75f, 75.75f, 76.25f, 38.25f, 25.416666f, 51.00f, 51.333332f, 25.75f, 50.833332f, 102.00f, 102.666664f, 51.50f, 77.00f, 154.50f, 155.50f, 78.00f, 78.50f, 157.50f, 158.50f, 79.50f, + 52.833332f, 106.00f, 106.666664f, 53.50f, 54.833332f, 110.00f, 110.666664f, 55.50f, 83.00f, 166.50f, 167.50f, 84.00f, 84.50f, 169.50f, 170.50f, 85.50f, 56.833332f, 114.00f, 114.666664f, + 57.50f, 28.416666f, 57.00f, 57.333332f, 28.75f, 43.00f, 86.25f, 86.75f, 43.50f, 43.75f, 87.75f, 88.25f, 44.25f, 29.416666f, 59.00f, 59.333332f, 29.75f, 30.416666f, 61.00f, 61.333332f, 30.75f, + 46.00f, 92.25f, 92.75f, 46.50f, 46.75f, 93.75f, 
94.25f, 47.25f, 31.416666f, 63.00f, 63.333332f, 31.75f, 62.833332f, 126.00f, 126.666664f, 63.50f, 95.00f, 190.50f, 191.50f, 96.00f, 96.50f, + 193.50f, 194.50f, 97.50f, 64.833336f, 130.00f, 130.666672f, 65.50f, 66.833336f, 134.00f, 134.666672f, 67.50f, 101.00f, 202.50f, 203.50f, 102.00f, 102.50f, 205.50f, 206.50f, 103.50f, + 68.833336f, 138.00f, 138.666672f, 69.50f, 34.416668f, 69.00f, 69.333336f, 34.75f, 52.00f, 104.25f, 104.75f, 52.50f, 52.75f, 105.75f, 106.25f, 53.25f, 35.416668f, 71.00f, 71.333336f, 35.75f}); + input.linspace(1.); + + sd::ops::avgpool3dnew op; + auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool3d_test1) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}, {20.f, 21.f, 23.f, 24.f, 32.f, 33.f, 35.f, 36.f, 56.f, 57.f, 59.f, 60.f, 68.f, 69.f, 71.f, 72.f, 92.f, 93.f, 95.f, 96.f, 104.f, 105.f, 107.f, 108.f, + 128.f, 129.f, 131.f, 132.f, 140.f, 141.f, 143.f, 144.f, 164.f, 165.f, 167.f, 168.f, 176.f, 177.f, 179.f, 180.f, 200.f, 201.f, 203.f, 204.f, 212.f, 213.f, 215.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + 
+////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool3d_test2) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=3,oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, { 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, + 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, + 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, 88.f, 89.f, 90.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, 106.f, 107.f, 108.f, + 157.f, 158.f, 159.f, 160.f, 161.f, 162.f, 160.f, 161.f, 162.f, 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 178.f, 179.f, 180.f, + 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, + 193.f, 194.f, 195.f, 196.f, 197.f, 198.f, 196.f, 197.f, 198.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f, 214.f, 215.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = 
op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool3d_test3) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto expected = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}, {58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 94.f, 95.f, 96.f, 97.f, 98.f, 99.f, 103.f, 104.f, 105.f, 106.f, 107.f, 108.f, + 166.f, 167.f, 168.f, 169.f, 170.f, 171.f, 175.f, 176.f, 177.f, 178.f, 179.f, 180.f, 202.f, 203.f, 204.f, 205.f, 206.f, 207.f, 211.f, 212.f, 213.f, 214.f, 215.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool3d_test4) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; + int oD=4,oH=4,oW=4; + int paddingMode = 0; // -SAME, 0-VALID + int dataFormat = 0; // -NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto expected = NDArrayFactory::create('c', {bS, iC, oD, oH, oW},{ 4.f, 5.f, 6.f, 6.f, 7.f, 8.f, 9.f, 9.f, 10.f, 11.f, 12.f, 12.f, 10.f, 11.f, 12.f, 12.f, 16.f, 17.f, 18.f, 18.f, 19.f, 20.f, 21.f, 
21.f, 22.f, 23.f, 24.f, 24.f, 22.f, 23.f, 24.f, 24.f, 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, + 28.f, 29.f, 30.f, 30.f, 31.f, 32.f, 33.f, 33.f, 34.f, 35.f, 36.f, 36.f, 34.f, 35.f, 36.f, 36.f, 40.f, 41.f, 42.f, 42.f, 43.f, 44.f, 45.f, 45.f, 46.f, 47.f, 48.f, 48.f, 46.f, 47.f, 48.f, 48.f, 52.f, 53.f, 54.f, 54.f, 55.f, 56.f, 57.f, 57.f, 58.f, 59.f, 60.f, 60.f, 58.f, 59.f, 60.f, 60.f, + 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 64.f, 65.f, 66.f, 66.f, 67.f, 68.f, 69.f, 69.f, 70.f, 71.f, 72.f, 72.f, 70.f, 71.f, 72.f, 72.f, 76.f, 77.f, 78.f, 78.f, 79.f, 80.f, 81.f, 81.f, 82.f, 83.f, 84.f, 84.f, 82.f, 83.f, 84.f, 84.f, + 88.f, 89.f, 90.f, 90.f, 91.f, 92.f, 93.f, 93.f, 94.f, 95.f, 96.f, 96.f, 94.f, 95.f, 96.f, 96.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, 100.f, 101.f, 102.f, 102.f, 103.f, 104.f, 105.f, 105.f, 106.f, 107.f, 108.f, 108.f, 106.f, 107.f, 108.f, 108.f, + 112.f, 113.f, 114.f, 114.f, 115.f, 116.f, 117.f, 117.f, 118.f, 119.f, 120.f, 120.f, 118.f, 119.f, 120.f, 120.f, 124.f, 125.f, 126.f, 126.f, 127.f, 128.f, 129.f, 129.f, 130.f, 131.f, 132.f, 132.f, 130.f, 131.f, 132.f, 132.f, 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, + 136.f, 137.f, 138.f, 138.f, 139.f, 140.f, 141.f, 141.f, 142.f, 143.f, 144.f, 144.f, 142.f, 143.f, 144.f, 144.f, 148.f, 149.f, 150.f, 150.f, 151.f, 152.f, 153.f, 153.f, 154.f, 155.f, 156.f, 156.f, 154.f, 155.f, 156.f, 156.f, 160.f, 161.f, 162.f, 162.f, 163.f, 164.f, 165.f, 165.f, 166.f, 167.f, 168.f, 168.f, 166.f, 167.f, 168.f, 168.f, + 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 172.f, 173.f, 174.f, 174.f, 175.f, 176.f, 177.f, 177.f, 178.f, 179.f, 180.f, 180.f, 178.f, 179.f, 180.f, 180.f, 184.f, 185.f, 186.f, 186.f, 187.f, 188.f, 
189.f, 189.f, 190.f, 191.f, 192.f, 192.f, 190.f, 191.f, 192.f, 192.f, + 196.f, 197.f, 198.f, 198.f, 199.f, 200.f, 201.f, 201.f, 202.f, 203.f, 204.f, 204.f, 202.f, 203.f, 204.f, 204.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f, 208.f, 209.f, 210.f, 210.f, 211.f, 212.f, 213.f, 213.f, 214.f, 215.f, 216.f, 216.f, 214.f, 215.f, 216.f, 216.f}); + input.linspace(1.); + + sd::ops::maxpool3dnew op; + auto results = op.evaluate({&input}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test1) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 
0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, + 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, + 0.333333f, 0.666667f, 0.333333f, 0.666667f, 1.333333f, 0.666667f, 0.666667f, 1.333333f, 0.666667f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f, 0.333333f, 0.666667f, 0.333333f, 0.333333f, 0.666667f, 0.333333f, 0.166667f, 0.333333f, 0.166667f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); 
+ ASSERT_TRUE(expected.equalsTo(output)); + +} + + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test2) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; + int oD=4,oH=4,oW=4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 
2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, + 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 1.333333f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 1.333333f, 1.333333f, 1.333333f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test3) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=3,oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 1.75f, 1.75f, 1.75f, 
3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.41667f, 0.41667f, 0.41667f, 0.83333f, 0.83333f, 0.83333f, 1.25f, 1.25f, 1.25f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 0.83333f, 0.83333f, 0.83333f, 1.66667f, 1.66667f, 1.66667f, 2.5f, 2.5f, 2.5f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f, + 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 1.25f, 1.25f, 1.25f, 2.5f, 2.5f, 2.5f, 3.75f, 3.75f, 3.75f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool3d_bp_test4) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=3,oH=4,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 
0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, + 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f, + 0.16667f, 0.16667f, 0.16667f, 0.33333f, 0.33333f, 0.33333f, 0.5f, 0.5f, 0.5f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.58333f, 0.58333f, 0.58333f, 1.16667f, 1.16667f, 1.16667f, 1.75f, 1.75f, 1.75f, + 0.91667f, 0.91667f, 0.91667f, 1.83333f, 1.83333f, 1.83333f, 2.75f, 2.75f, 2.75f, 0.33333f, 0.33333f, 0.33333f, 0.66667f, 0.66667f, 0.66667f, 1.f, 1.f, 1.f, 0.66667f, 0.66667f, 0.66667f, 1.33333f, 1.33333f, 1.33333f, 2.f, 2.f, 2.f, + 1.16667f, 1.16667f, 1.16667f, 2.33333f, 2.33333f, 2.33333f, 3.5f, 3.5f, 3.5f, 1.83333f, 1.83333f, 1.83333f, 3.66667f, 3.66667f, 3.66667f, 5.5f, 5.5f, 5.5f, 0.5f, 0.5f, 0.5f, 1.f, 1.f, 1.f, 1.5f, 1.5f, 1.5f, + 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 1.75f, 1.75f, 1.75f, 3.5f, 3.5f, 3.5f, 5.25f, 5.25f, 5.25f, 2.75f, 2.75f, 2.75f, 5.5f, 5.5f, 5.5f, 8.25f, 8.25f, 8.25f}); + input.linspace(1.); + gradO = 2.; + + sd::ops::avgpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + 
+} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test1) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=2,oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.6f, 0.f, 2.7f, 2.8f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.9f, 3.f, 0.f, 3.1f, 3.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.3f, 3.4f, 0.f, 3.5f, 3.6f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 0.f, 3.9f, 4.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.1f, 4.2f, 0.f, 4.3f, 4.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 4.5f, 4.6f, 0.f, 4.7f, 4.8f}); + + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + 
ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test2) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=1,pH=1,pW=1, dD=1,dH=1,dW=1; + int oD=4,oH=4,oW=4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oD, oH, oW}); + auto expected = NDArrayFactory::create('c', {bS, iC, iD, iH, iW}, {0.000e+00f, 0.000e+00f, 0.000e+00f, 1.000e-01f, 2.000e-01f, 7.000e-01f, 5.000e-01f, 6.000e-01f, 1.500e+00f, 2.200e+00f, 2.400e+00f, 5.400e+00f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.700e+00f, 1.800e+00f, 3.900e+00f, 2.100e+00f, 2.200e+00f, 4.700e+00f, 5.400e+00f, 5.600e+00f, 1.180e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.200e+00f, 8.400e+00f, 1.740e+01f, 9.000e+00f, 9.200e+00f, 1.900e+01f, 2.040e+01f, 2.080e+01f, 4.280e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 6.500e+00f, 6.600e+00f, 1.350e+01f, 6.900e+00f, 7.000e+00f, 1.430e+01f, 1.500e+01f, 1.520e+01f, 3.100e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 8.100e+00f, 8.200e+00f, 1.670e+01f, 8.500e+00f, 8.600e+00f, 1.750e+01f, 1.820e+01f, 1.840e+01f, 3.740e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.100e+01f, 2.120e+01f, 4.300e+01f, 2.180e+01f, 2.200e+01f, 4.460e+01f, 4.600e+01f, 4.640e+01f, 9.400e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.290e+01f, 1.300e+01f, 2.630e+01f, 1.330e+01f, 1.340e+01f, 2.710e+01f, 2.780e+01f, 2.800e+01f, 5.660e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.450e+01f, 1.460e+01f, 2.950e+01f, 1.490e+01f, 1.500e+01f, 3.030e+01f, 3.100e+01f, 3.120e+01f, 6.300e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.380e+01f, 3.400e+01f, 6.860e+01f, 3.460e+01f, 3.480e+01f, 7.020e+01f, 7.160e+01f, 7.200e+01f, 1.452e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 1.930e+01f, 1.940e+01f, 
3.910e+01f, 1.970e+01f, 1.980e+01f, 3.990e+01f, 4.060e+01f, 4.080e+01f, 8.220e+01f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.090e+01f, 2.100e+01f, 4.230e+01f, 2.130e+01f, 2.140e+01f, 4.310e+01f, 4.380e+01f, 4.400e+01f, 8.860e+01f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 4.660e+01f, 4.680e+01f, 9.420e+01f, 4.740e+01f, 4.760e+01f, 9.580e+01f, 9.720e+01f, 9.760e+01f, 1.964e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.570e+01f, 2.580e+01f, 5.190e+01f, 2.610e+01f, 2.620e+01f, 5.270e+01f, 5.340e+01f, 5.360e+01f, 1.078e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 2.730e+01f, 2.740e+01f, 5.510e+01f, 2.770e+01f, 2.780e+01f, 5.590e+01f, 5.660e+01f, 5.680e+01f, 1.142e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 5.940e+01f, 5.960e+01f, 1.198e+02f, 6.020e+01f, 6.040e+01f, 1.214e+02f, 1.228e+02f, 1.232e+02f, 2.476e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.210e+01f, 3.220e+01f, 6.470e+01f, 3.250e+01f, 3.260e+01f, 6.550e+01f, 6.620e+01f, 6.640e+01f, 1.334e+02f, + 0.000e+00f, 0.000e+00f, 0.000e+00f, 3.370e+01f, 3.380e+01f, 6.790e+01f, 3.410e+01f, 3.420e+01f, 6.870e+01f, 6.940e+01f, 6.960e+01f, 1.398e+02f, 0.000e+00f, 0.000e+00f, 0.000e+00f, 7.220e+01f, 7.240e+01f, 1.454e+02f, 7.300e+01f, 7.320e+01f, 1.470e+02f, 1.484e+02f, 1.488e+02f, 2.988e+02f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test3) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=1,dH=1,dW=1; + int oD=3,oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = 
NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, + 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, 24.6f, 0.f, 0.f, 0.f, 12.8f, 13.f, 13.2f, 27.4f, 27.8f, 28.2f, 0.f, 0.f, 0.f, 31.f, 31.4f, 31.8f, 65.6f, 66.39999f, 67.2f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, + 0.f, 0.f, 0.f, 11.8f, 11.9f, 12.f, 24.5f, 24.7f, 24.9f, 0.f, 0.f, 0.f, 26.3f, 26.5f, 26.7f, 54.4f, 54.8f, 55.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 34.4f, 34.6f, 34.8f, 70.6f, 71.f, 71.4f, 0.f, 0.f, 0.f, 74.2f, 74.6f, 75.f, 152.f, 152.8f, 153.6f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool3d_bp_test4) { + + int bS=2, iD=3,iH=4,iW=3, iC=3, kD=2,kH=3,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, 
dD=1,dH=1,dW=1; + int oD=3,oH=4,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oD, oH, oW, iC}); + auto expected = NDArrayFactory::create('c', {bS, iD, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 5.7f, 6.f, 6.3f, + 14.1f, 14.7f, 15.3f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 11.2f, 11.4f, 23.8f, 24.2f, + 24.6f, 0.f, 0.f, 0.f, 43.8f, 44.4f, 45.f, 93.f, 94.2f, 95.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 10.9f, 11.f, 11.1f, 22.7f, 22.9f, 23.1f, 0.f, 0.f, 0.f, 38.1f, 38.4f, 38.7f, 78.9f, 79.5f, 80.1f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.6f, 32.8f, 33.f, 67.f, 67.4f, 67.8f, 0.f, 0.f, 0.f, 108.6f, 109.2f, 109.8f, 222.6f, 223.8f, 225.f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool3dnew_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kD,kH,kW, sD,sH,sW, pD,pH,pW, dD,dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, maxpool2d_bp_1) { + 
+ auto input = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto epsilon = NDArrayFactory::create_('c', {bS,iD,oH,oW}); + auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->fillInputs({-2}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d_bp bp; + Nd4jStatus status = bp.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, maxpool2d_bp_2) { + + int bS=2, iD=1, iH=4,iW=4, oD=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; + int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; + + // TypeParam epsilonBuff[] = {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., 27., 28., 30., 31., 32.}; + // TypeParam expectedBuff[] = {0., 0., 0., 0.,0., 6., 7., 8.,0.,10.,11.,12.,0.,14.,15.,16.,0., 0., 0., 0.,0.,22.,23.,24.,0.,26.,27.,28.,0.,30.,31.,32.}; + + NDArray input('c', {bS,iD,iH,iW}); + NDArray epsilon('c', {bS,iD,oH,oW}, {6., 7., 8., 10., 11., 12., 14., 15., 16., 22., 23., 24., 26., 27., 28., 30., 31., 32.}); + NDArray expected('c', {bS,iD,iH,iW}, {0., 0., 0., 0.,0., 6., 7., 8.,0.,10.,11.,12.,0.,14.,15.,16.,0., 0., 0., 0.,0.,22.,23.,24.,0.,26.,27.,28.,0.,30.,31.,32.}); + + + input.linspace(1.); + + std::initializer_list argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - 
stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::maxpool2d_bp op; + auto results = op.evaluate({&input, &epsilon}, {}, argI); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_3) { + + int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.f, 0.3f, 0.4f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.5f, 0.6f, 0.f, 0.7f, 0.8f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.9f, 1.f, 0.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.3f, 1.4f, 0.f, 1.5f, 1.6f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 0.f, 1.9f, 2.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.1f, 2.2f, 0.f, 2.3f, 2.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_4) { + + int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=1,pW=1, dH=1,dW=1; + int oH=4,oW=4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = 
NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.f, 0.f, 0.f, 0.1f, 0.2f, 0.7f, 0.5f, 0.6f, 1.5f, 2.2f, 2.4f, 5.4f, 0.f, 0.f, 0.f, 1.7f, 1.8f, 3.9f, 2.1f, 2.2f, 4.7f, 5.4f, 5.6f, 11.8f, + 0.f, 0.f, 0.f, 3.3f, 3.4f, 7.1f, 3.7f, 3.8f, 7.9f, 8.6f, 8.8f, 18.2f, 0.f, 0.f, 0.f, 4.9f, 5.f, 10.3f, 5.3f, 5.4f, 11.1f, 11.8f, 12.f, 24.6f, + 0.f, 0.f, 0.f, 6.5f, 6.6f, 13.5f, 6.9f, 7.f, 14.3f, 15.f, 15.2f, 31.f, 0.f, 0.f, 0.f, 8.1f, 8.2f, 16.7f, 8.5f, 8.6f, 17.5f, 18.2f, 18.4f, 37.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_5) { + + int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 1.1f, 1.3f, 1.5f, 0.f, 0.f, 0.f, 1.f, 1.1f, 1.2f, 2.9f, 3.1f, 3.3f, + 0.f, 0.f, 0.f, 4.7f, 4.9f, 5.1f, 11.2f, 11.6f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 3.7f, 3.8f, 3.9f, 8.3f, 8.5f, 8.7f, + 0.f, 0.f, 0.f, 4.6f, 4.7f, 4.8f, 10.1f, 10.3f, 10.5f, 0.f, 0.f, 0.f, 11.9f, 12.1f, 12.3f, 25.6f, 26.f, 26.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), 
results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, maxpool2d_bp_6) { + + int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, + 0.f, 0.f, 0.f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 0.f, 0.f, 0.f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, maxpool2d_bp_7) { + + int bS=2, iH=56,iW=56, iC=3, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; + int oH=28,oW=28; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::maxpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); + // auto output = 
results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(expected.isSameShape(output)); + // ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, avgpool2d_bp_1) { + + auto input = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto epsilon = NDArrayFactory::create_('c', {bS,iD,oH,oW}); + auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->fillInputs({-2}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode, 9 - extraParam0 (unnecessary for avg mode), 10 - data format + + sd::ops::avgpool2d_bp bp; + Nd4jStatus status = bp.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_2) { + + int bS=2, iD=1, iH=4,iW=4, oD=3, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; + int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; + + // TypeParam epsilonBuff[] = {3.5 , 4.5 , 5.5, 7.5 , 8.5 , 9.5, 11.5, 12.5, 13.5, 19.5, 20.5, 21.5, 23.5, 24.5, 25.5, 27.5, 28.5, 29.5}; + // TypeParam expectedBuff[] = {0.875, 2., 2.5,1.375, 2.75 , 6., 7., 3.75, 4.75 ,10., 11., 5.75, 2.875, 6., 6.5, 3.375, 4.875, 10.,10.5, 5.375, 10.75, 22.,23., 11.75, 12.75, 26.,27., 13.75, 6.875, 14.,14.5, 7.375}; + + auto input = 
NDArrayFactory::create('c', {bS,iD,iH,iW}); + auto epsilon = NDArrayFactory::create('c', {bS,iD,oH,oW}, {3.5f, 4.5f, 5.5f, 7.5f, 8.5f, 9.5f, 11.5f, 12.5f, 13.5f, 19.5f, 20.5f, 21.5f, 23.5f, 24.5f, 25.5f, 27.5f, 28.5f, 29.5f}); + auto expected = NDArrayFactory::create('c', {bS,iD,iH,iW}, {0.875f, 2.f, 2.5f, 1.375f, 2.75f, 6.f, 7.f, 3.75f, 4.75f, 10.f, 11.f, 5.75f, 2.875f, 6.f, 6.5f, 3.375f, 4.875f, 10.f, 10.5f, 5.375f, 10.75f, 22.f, 23.f, 11.75f, 12.75f, 26.f, 27.f, 13.75f, 6.875f, 14.f, 14.5f, 7.375f}); + + input.linspace(1.); + + std::initializer_list argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 1, 0}; + + sd::ops::avgpool2d_bp op; + auto results = op.evaluate({&input, &epsilon}, {}, argI); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_3) { + + int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.016667f, 0.05f, 0.033333f, 0.066667f, 0.166667f, 0.1f, 0.066667f, 0.166667f, 0.1f, 0.05f, 0.116667f, 0.066667f, + 0.083333f, 0.183333f, 0.1f, 0.2f, 0.433333f, 0.233333f, 0.2f, 0.433333f, 0.233333f, 0.116667f, 0.25f, 0.133333f, + 0.15f, 0.316667f, 0.166667f, 0.333333f, 0.7f, 0.366667f, 0.333333f, 0.7f, 0.366667f, 0.183333f, 0.383333f, 0.2f, + 0.216667f, 0.45f, 0.233333f, 0.466667f, 0.966667f, 0.5f, 0.466667f, 0.966667f, 0.5f, 0.25f, 0.516667f, 0.266667f, + 0.283333f, 0.583333f, 0.3f, 0.6f, 1.233333f, 0.633333f, 0.6f, 1.233333f, 0.633333f, 0.316667f, 0.65f, 0.333333f, + 0.35f, 0.716667f, 0.366667f, 0.733333f, 1.5f, 0.766667f, 0.733333f, 1.5f, 0.766667f, 0.383333f, 0.783333f, 
0.4f }); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_4) { + + int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=1,pW=1, dH=1,dW=1; + int oH=4,oW=4; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.233333f, 0.3f, 0.366667f, 0.55f, 0.65f, 0.75f, 0.95f, 1.05f, 1.15f, 0.766667f, 0.833333f, 0.9f, + 1.3f, 1.366667f, 1.433333f, 2.15f, 2.25f, 2.35f, 2.55f, 2.65f, 2.75f, 1.833333f, 1.9f, 1.966667f, + 2.366667f, 2.433333f, 2.5f, 3.75f, 3.85f, 3.95f, 4.15f, 4.25f, 4.35f, 2.9f, 2.966667f, 3.033333f, + 3.433333f, 3.5f, 3.566667f, 5.35f, 5.45f, 5.55f, 5.75f, 5.85f, 5.95f, 3.966667f, 4.033333f, 4.1f, + 4.5f, 4.566667f, 4.633333f, 6.95f, 7.05f, 7.15f, 7.35f, 7.45f, 7.55f, 5.033333f, 5.1f, 5.166667f, + 5.566667f, 5.633333f, 5.7f, 8.549999f, 8.65f, 8.75f, 8.95f, 9.05f, 9.150001f, 6.1f, 6.166667f, 6.233334f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + + +//////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_5) { + + int bS=2, iH=4,iW=3, iC=3, 
kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.19167f, 0.23333f, 0.275f, 0.50833f, 0.59167f, 0.675f, 1.2f, 1.325f, 1.45f, 0.50833f, 0.56667f, 0.625f, 1.19167f, 1.30833f, 1.425f, 2.4f, 2.575f, 2.75f, + 1.18333f, 1.24167f, 1.3f, 2.54167f, 2.65833f, 2.775f, 4.425f, 4.6f, 4.775f, 1.01667f, 1.05833f, 1.1f, 2.15833f, 2.24167f, 2.325f, 3.675f, 3.8f, 3.925f, + 1.69167f, 1.73333f, 1.775f, 3.50833f, 3.59167f, 3.675f, 5.7f, 5.825f, 5.95f, 2.60833f, 2.66667f, 2.725f, 5.39167f, 5.50833f, 5.625f, 8.7f, 8.875f, 9.05f, + 3.28333f, 3.34167f, 3.4f, 6.74167f, 6.85833f, 6.975f, 10.725f, 10.9f, 11.075f, 2.51667f, 2.55833f, 2.6f, 5.15833f, 5.24167f, 5.325f, 8.175f, 8.3f, 8.425f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, avgpool2d_bp_6) { + + int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, iC}); + auto expected = NDArrayFactory::create('c', {bS, iH, iW, iC}, {0.01667f, 0.03333f, 0.05f, 0.08333f, 0.11667f, 0.15f, 0.06667f, 0.08333f, 0.1f, 0.13333f, 0.16667f, 0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, + 0.13333f, 0.16667f, 
0.2f, 0.36667f, 0.43333f, 0.5f, 0.23333f, 0.26667f, 0.3f, 0.11667f, 0.13333f, 0.15f, 0.28333f, 0.31667f, 0.35f, 0.16667f, 0.18333f, 0.2f, + 0.21667f, 0.23333f, 0.25f, 0.48333f, 0.51667f, 0.55f, 0.26667f, 0.28333f, 0.3f, 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, + 0.53333f, 0.56667f, 0.6f, 1.16667f, 1.23333f, 1.3f, 0.63333f, 0.66667f, 0.7f, 0.31667f, 0.33333f, 0.35f, 0.68333f, 0.71667f, 0.75f, 0.36667f, 0.38333f, 0.4f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::avgpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 1, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, pnormpool2d_bp_1) { + + auto input = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto epsilon = NDArrayFactory::create_('c', {bS,iD,oH,oW}); + auto exp = NDArrayFactory::create('c', {bS,iD,iH,iW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + block->fillInputs({-2}); + auto argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 3}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - divisor + std::vector* argT = block->getTArguments(); + *argT = {0.000001}; + + sd::ops::pnormpool2d_bp bp; + Nd4jStatus status = bp.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + 
+////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_2) { + + int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int pnorm = 3; + double eps = 0.; + + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {9.661570e-04f, 9.671602e-03f, 1.306569e-02f, 3.679184e-02f, 1.297220e-01f, 1.040181e-01f, 1.126750e-01f, 3.320884e-01f, 2.340406e-01f, 1.333333e-01f, 3.352886e-01f, 2.070211e-01f, + 8.991618e-02f, 2.160601e-01f, 1.283173e-01f, 2.744226e-01f, 6.364498e-01f, 3.662123e-01f, 3.869788e-01f, 8.808994e-01f, 4.984556e-01f, 2.613189e-01f, 5.818475e-01f, 3.225517e-01f, + 2.065654e-01f, 4.553546e-01f, 2.501175e-01f, 5.190718e-01f, 1.131343e+00f, 6.148388e-01f, 6.362602e-01f, 1.377521e+00f, 7.439550e-01f, 3.833026e-01f, 8.227519e-01f, 4.407146e-01f, + 3.261206e-01f, 6.969233e-01f, 3.717564e-01f, 7.627507e-01f, 1.620991e+00f, 8.600952e-01f, 8.814538e-01f, 1.866888e+00f, 9.873542e-01f, 5.046682e-01f, 1.064004e+00f, 5.602558e-01f, + 4.464697e-01f, 9.389536e-01f, 4.932274e-01f, 1.005908e+00f, 2.108550e+00f, 1.104095e+00f, 1.125322e+00f, 2.354009e+00f, 1.230180e+00f, 6.258913e-01f, 1.305581e+00f, 6.804127e-01f, + 5.671396e-01f, 1.181128e+00f, 6.145977e-01f, 1.248783e+00f, 2.595083e+00f, 1.347494e+00f, 1.368600e+00f, 2.840157e+00f, 1.472778e+00f, 7.470673e-01f, 1.547362e+00f, 8.008900e-01f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::pnormpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + 
+////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, pnormpool2d_bp_3) { + + int bS=2, iH=4,iW=3, iC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int pnorm = 2; + double eps = 0.; + + int paddingMode = 0; // 1-SAME, 0-VALID + int dataFormat = 0; // 1-NDHWC, 0-NCDHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, oH, oW}); + auto expected = NDArrayFactory::create('c', {bS, iC, iH, iW}, {0.007931f, 0.042891f, 0.040544f, 0.09369f, 0.276841f, 0.191675f, 0.163957f, 0.442946f, 0.287512f, 0.154919f, 0.373153f, 0.221172f, + 0.15901f, 0.365232f, 0.207846f, 0.428282f, 0.959455f, 0.534076f, 0.508585f, 1.128771f, 0.623089f, 0.319794f, 0.698063f, 0.379547f, + 0.321068f, 0.692438f, 0.372316f, 0.757521f, 1.620323f, 0.864566f, 0.838684f, 1.787943f, 0.951023f, 0.483194f, 1.023434f, 0.541058f, + 0.483937f, 1.019414f, 0.536145f, 1.085348f, 2.276996f, 1.192917f, 1.166749f, 2.443606f, 1.278126f, 0.646499f, 1.349361f, 0.703463f, + 0.647021f, 1.346249f, 0.699745f, 1.412654f, 2.932174f, 1.520512f, 1.494153f, 3.098146f, 1.604985f, 0.809791f, 1.675544f, 0.866229f, + 0.810192f, 1.673009f, 0.863237f, 1.739711f, 3.58665f, 1.847753f, 1.82126f, 3.752188f, 1.931741f, 0.973081f, 2.001861f, 1.029173f}); + input.linspace(1.); + gradO.linspace(0.1, 0.1); + + sd::ops::pnormpool2d_bp op; + auto results = op.evaluate({&input, &gradO}, {eps}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, pnorm, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, upsampling2d_bp_1) { + + const int bS=1, iH=2,iW=2, iC=1; + const int factorH=2, factorW=2; + const int isNCHW = 1; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', 
{bS, iC, iH, iW}); + auto gradO = NDArrayFactory::create('c', {bS, iC, iH*factorH, iW*factorW}); + gradO = 1.; + + auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW}); + expGradI = 4.; + + sd::ops::upsampling2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); + auto* gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, upsampling2d_bp_2) { + + const int bS=1, iH=2,iW=2, iC=1; + const int factorH=2, factorW=2; + const int isNCHW = 0; // data format, default is NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto gradO = NDArrayFactory::create('c', {bS, iH*factorH, iW*factorW, iC}); + gradO = 1.; + + auto expGradI = NDArrayFactory::create('c', {bS, iH, iW, iC}); + expGradI = 4.; + + sd::ops::upsampling2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); + auto* gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, upsampling2d_bp_3) { + + const int bS=1, iH=3,iW=3, iC=2; + const int factorH=2, factorW=2; + const int isNCHW = 1; // data format, default is NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + + NDArray gradO('c', {bS, iC, iH*factorH, iW*factorW}, {0.6793504, 0.35508695, 0.84278935, 0.20031333, 0.7014987, 0.31069338, 0.44793984, + 0.93800974, 0.32667395, 0.15187258, 0.38331753, 0.78212297, 0.1988072, 0.7985636, 0.1632634, 0.14696825, 0.26089668, 0.13505761, + 0.7562093, 0.27545404, 0.36908787, 0.09282647, 0.83649176, 0.26841334, 0.09506222, 0.31279507, 0.13591796, 0.5175439, 0.32870287, + 0.061735712, 0.39643127, 0.248016, 0.5489592, 0.115046196, 0.8143622, 
0.7215636, 0.40449402, 0.29908907, 0.4038839, 0.9883108, + 0.022296403, 0.927782, 0.3184157, 0.0685462, 0.28453344, 0.23272, 0.35214192, 0.058909304, 0.7112212, 0.6744568, 0.19694561, 0.6994972, + 0.0743224, 0.42042503, 0.5842631, 0.14957358, 0.44640633, 0.72307247, 0.06448108, 0.48307765, 0.8759956, 0.5698191, 0.4458631, 0.5277549, + 0.016646361, 0.753678, 0.14063567, 0.7541292, 0.16193217, 0.7750374, 0.3326449, 0.11739397}, sd::DataType::FLOAT32); + + NDArray expGradI('c', {bS, iC, iH, iW}, {2.4203868, 1.5216494, 2.1776323, 2.0290341, 0.772146, 1.5008594, 1.0523045, 1.3174672, 1.9263644, + 1.090545, 1.9094483, 1.3611296, 2.1195147, 2.0659215, 1.0423062, 2.3405795, 1.9105877, 1.2203633}, sd::DataType::FLOAT32); + + sd::ops::upsampling2d_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {isNCHW}); + auto* gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedConvolutionTests2, depthwise_conv2d_1) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=4,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f, + 12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, + 13.2f, 14.4f, 15.6f, 
16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_2) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + + auto expOutput = NDArrayFactory::create('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, + 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f}); + input = 2.; + weights.linspace(0.1, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_3) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = 
NDArrayFactory::create('c', {mC, iC, kH, kW}); + auto biases = NDArrayFactory::create('c', {iC*mC}, {1.f,2.f,3.f,4.f}); + + NDArray expOutput('c', {bS, oC, oH, oW},{5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8, 5.2, 5.2, 5.2, 5.2,20.6,20.6,20.6,20.6,14.4,14.4,14.4,14.4,29.8,29.8,29.8,29.8}, sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2,3,1,0}); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_4) { + + int bS=1, iH=111,iW=111, iC=32,mC=1, kH=7,kW=7, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=56,oW=56; + + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + const float unique = -1000000; + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + input.linspace(0.1, 0.0001); + weights = 0.5; + output = unique; + + sd::ops::depthwise_conv2d op; + Nd4jStatus status = op.execute({&input, &weights}, {&output} , {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}, {}); + + ASSERT_EQ(Status::OK(), status); + + for(Nd4jLong i=output.lengthOf()/1.5; i < output.lengthOf(); ++i) + ASSERT_EQ(output.e(i) != unique, true); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_5) { + + int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=3,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 
1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + + NDArray expOutput('c', {bS, oH, oW, oC}, {10., 12., 14., 16., 8., 9., 22., 24., 26., 28., 14., 15., 14., 15., 16., 17., 8.5, 9.}, sd::DataType::FLOAT32); + + input.linspace(1.); + weights = 0.5; + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_6) { + + int bS=1, iH=3,iW=3, iC=2,mC=1, kH=2,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=3,oW=3; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oH, oW, oC}, {20., 24.,28., 32.,16., 18.,44., 48.,52., 56.,28., 30.,28., 30.,32., 34.,17., 18.}, sd::DataType::FLOAT32); + input.linspace(1.); + weights = 1.; + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + NDArray* output = results.at(0); + // output.printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_7) { + + int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, 
iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, + 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, + 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, {0.1308445781469345, 0.6442840099334717, 0.5698848366737366, 0.19896849989891052}, sd::DataType::FLOAT32); + NDArray biases('c', {1,iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 0.6385665758716108, 0.7041439625563628, 0.6530092074102978, + 0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472, + 0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504, + 0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, + 0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, sd::DataType::FLOAT32); + + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_8) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, 
sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oH, oW, oC}, {-42.879997, -43.959999, -44.959999, -45.879997, -46.720005, -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, + -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, 22.059998, + 25.600000, 26.860001, 28.239998, 29.739998, 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, 63.900002, 64.559998, + 65.340004, 106.080002, 106.169998, 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, + 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987, + 
192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990, + 214.889999, 222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988, + 288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, + 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 
438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 
322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 
633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + 
sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results.at(0); + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_9) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oC, oH, oW}, {-103.360001, -131.440002, -130.000000, -128.559998, -127.120003, -125.680000, -124.240005, -122.799995, -121.360001, -66.720001,-76.199997, -81.239998, -80.160004, -79.080002, -78.000000, -76.919998, -75.840004, -74.760002, -73.680000, -29.400002, -66.599998, -70.440002, -69.360001, -68.279999, + -67.199997, -66.120003, -65.040001, -63.959999, -62.879997, -24.599997, -57.000000, -59.639999, -58.560005, -57.479996, -56.399998, -55.320000, -54.240002, -53.159996, -52.080002, -19.799997, -47.400002, -48.840000, -47.760002, -46.680000, -45.599998, -44.520000, -43.440002, -42.360001, -41.279999, -15.000000, -37.799999, -38.040001, + -36.959999, -35.879997, -34.799999, -33.720001, -32.639999, -31.560001, -30.479996, -10.199999, -28.200001, -27.240002, -26.160000, -25.080002, -24.000000, -22.919998,-21.840000, -20.759998, -19.679998, -5.400000, -18.599998, -16.439999, -15.360001, -14.280001, -13.200001, -12.120001, -11.040000, -9.960001, -8.880000, -0.600000, + -9.000000, -5.639999, -4.560000, -3.480000, -2.400000, -1.320001, -0.240000, 0.840001, 1.920000, 4.200000, 0.160000, 3.920000, 3.920000, 3.920000, 3.920000, 3.920000,3.920001, 3.920000, 
3.920000, 3.520000, 8.860001, 12.920000, 14.420000, 15.920000, 17.420000, 18.920000, 20.420000, 21.920000, 23.420000, 13.820000, 20.430000, 27.750000, + 28.919998, 30.090000, 31.260000, 32.430000, 33.600002, 34.770000, 35.939999, 19.709999, 30.630001, 39.450001, 40.619999, 41.790001, 42.960003, 44.129997, 45.299999, 46.470001, 47.639999, 25.110001, 40.829998, 51.150002, 52.320000, 53.489998, 54.660004, 55.829994, 57.000000, 58.169998, 59.340004, 30.510002, 51.029999, 62.849998, + 64.019997, 65.190002, 66.360001, 67.529999, 68.699997, 69.870003, 71.040001, 35.910000, 61.229996, 74.550003, 75.720001, 76.889999, 78.059998, 79.229996, 80.400002, 81.570000, 82.740005, 41.310001, 71.430000, 86.250000, 87.419998, 88.589996, 89.760002, 90.929993, 92.099991, 93.270004, 94.440002, 46.709999, 81.630005, 97.949997, + 99.120003, 100.290009, 101.459999, 102.630005, 103.800003, 104.970001, 106.139999, 52.110001, 91.830002, 109.649994, 110.820007, 111.990005, 113.159996, 114.330002, 115.500000, 116.669998, 117.839996, 57.509995, 19.580000, 9.079998, 9.139999, 9.199999, 9.259996, 9.320001, 9.379998, 9.440000, 9.500000, -8.740000, 129.080002, 169.279999, + 170.839996, 172.399994, 173.960007, 175.520004, 177.080002, 178.639999, 180.199982, 102.360001, 129.059998, 154.739990, 156.000000, 157.259995, 158.520004, 159.779999, 161.039993, 162.300003, 163.559998, 80.820000, 139.860001, 167.340012, 168.600006, 169.860001, 171.119995, 172.380005, 173.639999, 174.899994, 176.160004, 86.820000, + 150.660004, 179.940002, 181.200012, 182.459991, 183.720001, 184.980011, 186.239990, 187.500000, 188.759995, 92.820007, 161.459991, 192.540009, 193.799988, 195.059998, 196.319992, 197.579987, 198.839996, 200.100006, 201.360001, 98.820000, 172.259995, 205.139999, 206.399994, 207.660004, 208.919983, 210.179993, 211.440002, 212.700012, + 213.959991, 104.819992, 183.059998, 217.739990, 219.000000, 220.259995, 221.519989, 222.779999, 224.039993, 225.300018, 226.559998, 110.819992, 193.860016, 230.339996, 
231.600006, 232.860001, 234.119995, 235.380005, 236.639999, 237.900009, 239.160004, 116.820000, 204.660004, 242.940002, 244.199982, 245.459991, 246.720001, 247.980011, + 249.239990, 250.500000, 251.759995, 122.819992, 47.000000, 26.240004, 26.360004, 26.479998, 26.600002, 26.720001, 26.840002, 26.959997, 27.080000, -12.999998, 257.299988, 337.640015, 339.260010, 340.879974, 342.499969, 344.119995, 345.740021, 347.359985, 348.979980, 198.899994, 249.690002, 299.729980, 301.079987, 302.429993, 303.779999, 305.130005, 306.480011, 307.829987, 309.179993, 153.929993, 261.089996, 313.230011, 314.580017, 315.929993, 317.279968, 318.630005, 319.979980, 321.329987, 322.679993, 160.529999, 272.489990, 326.729980, 328.079987, 329.429993, 330.779968, 332.130005, 333.479980, 334.829987, 336.179993, 167.130005, 283.889984, 340.230011, 341.580017, 342.929993, 344.279999, 345.630005, 346.980011, 348.330017, 349.679993, 173.729996, 295.289978, 353.729980, 355.079987, 356.429993, 357.779968, 359.130005, 360.480011, 361.829987, 363.179993, 180.329987, 306.690002, 367.230011, 368.580017, 369.929993, 371.279999, 372.630005, 373.980011, 375.330017, 376.679993, 186.929993, 318.089996, 380.729980, 382.080017, 383.429993, 384.779968, 386.130005, 387.479980, 388.829987, 390.179993, 193.529984, 329.489990, 394.229980, 395.579987, 396.929993, 398.279999, 399.630005, 400.980011, 402.330017, 403.679993, 200.130005, 82.419998, 55.400005, 55.580002, 55.759995, 55.939999, 56.120003, 56.299995, 56.479996, 56.659996, -9.260002, 393.520020, 518.000000, 519.679993, 521.359985, 523.040039, 524.720032, 526.400024, 528.080017, 529.760010, 303.440002, 382.320007, 462.720032, 464.160004, 465.600037, 467.040009, 468.479980, 469.919983, 471.359985, 472.800018, 239.040009, 394.320007, 477.119995, 478.559998, 480.000000, 481.440002, 482.880005, 484.320007, 485.760010, 487.200012, 246.240005, 406.320007, 491.520020, 492.960022, 494.400024, 495.839996, 497.280029, 498.720032, 500.160004, 501.600037, 253.440002, 
418.320007, 505.919983, 507.359985, 508.800018, 510.240051, 511.680023, 513.119995, 514.559998, 516.000000, 260.640015, 430.319977, 520.320007, 521.760010, 523.200012, 524.640015, 526.079956, 527.520020, 528.960022, 530.400024, 267.839996, 442.320007, 534.720032, 536.160034, 537.600037, 539.040039, 540.479980, 541.919983, 543.359985, 544.800049, 275.040009, 454.320007, 549.119995, 550.559998, 552.000000, 553.440002, 554.880005, 556.320007, 557.760010, 559.200012, 282.239990, 466.320007, 563.520020, 564.960022, 566.400024, 567.839966, 569.280029, 570.720032, 572.160034, 573.600037, 289.440002, 125.839996, 96.559998, 96.799995, 97.040009, 97.280014, 97.520004, 97.759995, 98.000000, 98.240013, 2.480007, 537.739990, 710.359985, 712.099976, 713.840027, 715.579956, 717.319946, 719.059998, 720.799988, 722.539978, 415.980011, 526.950012, 643.710022, 645.240051, 646.770020, 648.300049, 649.829956, 651.359985, 652.890015, 654.419983, 336.149994, 539.549988, 659.010010, 660.539978, 662.070007, 663.600037, 665.130005, 666.660034, 668.190002, 669.720032, 343.950012, 552.150024, 674.309998, 675.839966, 677.369995, 678.900024, 680.429993, 681.960022, 683.490051, 685.020020, 351.750000, 564.750000, 689.609985, 691.140015, 692.669983, 694.200012, 695.729980, 697.260010, 698.789978, 700.320007, 359.549988, 577.349976, 704.910034, 706.440002, 707.970032, 709.500000, 711.029968, 712.559998, 714.089966, 715.619995, 367.350037, 589.950012, 720.210022, 721.740051, 723.270020, 724.800049, 726.329956, 727.859985, 729.390015, 730.919983, 375.149994, 602.549988, 735.510010, 737.039978, 738.570007, 740.100037, 741.630005, 743.160034, 744.690002, 746.220032, 382.950012, 615.150024, 750.809998, 752.339966, 753.869995, 755.399963, 756.929993, 758.460022, 759.990051, 761.520020, 390.750000, 177.260010, 149.720001, 150.020004, 150.319992, 150.619995, 150.919998, 151.220001, 151.520004, 151.819992, 22.220009, 689.959961, 914.720032, 916.519958, 918.319946, 920.119995, 921.919983, 923.719971, 
925.520020, 927.320007, 536.519958, 683.579956, 842.699951, 844.319946, 845.940002, 847.559998, 849.179993, 850.799988, 852.419983, 854.039978, 445.260010, 696.779968, 858.900024, 860.520020, 862.140015, 863.760010, 865.380005, 867.000000, 868.619995, 870.239990, 453.659973, 709.979980, 875.099976, 876.719971, 878.339966, 879.959961, 881.579956, 883.199951, 884.819946, 886.440002, 462.059998, 723.179993, 891.299988, 892.919983, 894.539978, 896.159973, 897.779968, 899.400024, 901.020020, 902.640015, 470.459991, 736.380005, 907.500000, 909.119995, 910.739990, 912.359985, 913.979980, 915.599976, 917.219971, 918.839966, 478.859985, 749.579956, 923.699951, 925.319946, 926.940002, 928.559998, 930.179993, 931.799988, 933.419983, 935.039978, 487.260010, 762.779968, 939.900024, 941.520020, 943.140015, 944.760010, 946.380005, 948.000000, 949.619995, 951.239990, 495.659973, 775.979980, 956.099976, 957.719971, 959.339966, 960.959961, 962.579956, 964.199951, 965.819946, 967.440002, 504.059998, 236.679977, 214.880005, 215.239990, 215.599991, 215.959991, 216.319992, 216.679993, 217.040009, 217.399994, 49.959995, 850.180054, 1131.079956, 1132.939941, 1134.800049, 1136.660034, 1138.520020, 1140.380005, 1142.239990, 1144.100098, 665.060059, 852.209961, 1059.689941, 1061.399902, 1063.110107, 1064.820068, 1066.530029, 1068.239990, 1069.950073, 1071.660034, 566.370056, 866.010010, 1076.790039, 1078.500000, 1080.209961, 1081.920044, 1083.630005, 1085.339966, 1087.050049, 1088.760010, 575.369995, 879.809998, 1093.890015, 1095.599976, 1097.310059, 1099.020020, 1100.729980, 1102.439941, 1104.149902, 1105.859985, 584.369995, 893.609985, 1110.989990, 1112.699951, 1114.410034, 1116.120117, 1117.830078, 1119.540039, 1121.250000, 1122.959961, 593.370056, 907.410034, 1128.089966, 1129.800049, 1131.510010, 1133.220093, 1134.929932, 1136.639893, 1138.349976, 1140.060059, 602.369995, 921.209961, 1145.189941, 1146.900024, 1148.609985, 1150.320068, 1152.030029, 1153.739990, 1155.449951, 1157.160034, 
611.370056, 935.010010, 1162.290039, 1164.000000, 1165.709961, 1167.420044, 1169.130005, 1170.839966, 1172.550049, 1174.260010, 620.369995, 948.809998, 1179.390015, 1181.099976, 1182.810059, 1184.520020, 1186.229980, 1187.939941, 1189.650024, 1191.359985, 629.370056, 304.099976, 292.039978, 292.460022, 292.880005, 293.300018, 293.720001, 294.140015, 294.559998, 294.980042, 85.700005}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output, 1e-4)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_10) { + + int bS=1, iH=3,iW=3, iC=2,mC=2, kH=1,kW=1, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=3,oW=3; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, {0.6793503761291504, 0.35508695244789124, 0.842789351940155, 0.20031332969665527, 0.7014986872673035, 0.3106933832168579, + 0.44793984293937683, 0.9380097389221191, 0.3266739547252655, 0.15187257528305054, 0.3833175301551819, 0.7821229696273804, + 0.19880719482898712, 0.7985635995864868, 0.16326339542865753, 0.14696824550628662, 0.2608966827392578, 0.13505761325359344}, sd::DataType::FLOAT32); + NDArray weights('c', {mC, iC, kH, kW}, {0.130845, 0.569885, 0.644284, 0.198968}, sd::DataType::FLOAT32); + NDArray biases('c', {iC*mC}, {0.6123566627502441, 0.37637925148010254, 0.17464971542358398, 0.4270855486392975}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oC, oH, oW}, {0.7012459761288241, 0.6588178652487691, 0.722631079971582, 
0.6385665758716108, 0.7041439625563628, 0.6530092074102978, + 0.670967162534851, 0.735090151337225, 0.6551001785478623, 0.8140738359624038, 0.6051560970782859, 0.9193749546773375, 0.5054379267801892, 0.8283436386757472, + 0.5765540302788565, 0.6649797296980537, 0.9807239274294943, 0.586850056971322, 0.261199593183985, 0.3930965634902499, 0.6203697362284615, 0.28794692117826504, + 0.6297390019475202, 0.26769104886224415, 0.25840469001015975, 0.3233307788551656, 0.25161700129415276, 0.4573034071191504, 0.5033536625992294, 0.5827033826425385, + 0.4666419179635315, 0.585974550122895, 0.4595698215161401, 0.45632759998045813, 0.4789957702325296, 0.4539577593482922}, sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights, &biases}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto* output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_11) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + int wFormat = 2; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {mC, kH, kW, iC}, {-2., -1.9, -1.8, -1.7, -1.6, -1.5, -1.4, -1.3, -1.2, -1.1, -1., -0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, + 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., + 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6, 4.7, 4.8, 4.9, 5., 5.1}, sd::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oH, oW, oC}, {-42.879997, 
-43.959999, -44.959999, -45.879997, -46.720005, -47.480003, -48.160000, -48.760002, -43.519997, -45.139999, -46.639996, -48.020000, -49.280003, -50.419998, -51.440006, -52.340000, -31.999998, -33.139999, -34.160000, -35.060001, -35.840004, -36.500004, -37.039997, -37.459999, -20.480000, + -21.139997, -21.680000, -22.100000, -22.399998, -22.579998, -22.639996, -22.580002, -8.960000, -9.139998, -9.200002, -9.140001, -8.960001, -8.660000, -8.240002, -7.700001, 2.560000, 2.860002, 3.279998, 3.820000, 4.480001, 5.260000, 6.160001, 7.180000, 14.080000, 14.860000, 15.759998, 16.779999, 17.920002, 19.180000, 20.560001, 22.059998, + 25.600000, 26.860001, 28.239998, 29.739998, 31.360001, 33.099998, 34.959999, 36.939999, 37.119999, 38.860001, 40.720001, 42.699997, 44.800003, 47.020000, 49.360001, 51.820000, 26.239998, 27.400002, 28.639999, 29.959999, 31.360001, 32.840000, 34.400002, 36.040001, 62.400002, 62.459999, 62.639999, 62.940002, 63.360001, 63.900002, 64.559998, + 65.340004, 106.080002, 106.169998, 106.440002, 106.889999, 107.519997, 108.330002, 109.320000, 110.490005, 114.720001, 115.529999, 116.520004, 117.690002, 119.040009, 120.570000, 122.279999, 124.169998, 123.359985, 124.889999, 126.599998, 128.490005, 130.559998, 132.809998, 135.240005, 137.850006, 132.000000, 134.250000, 136.679993, + 139.290009, 142.080002, 145.049988, 148.199997, 151.529999, 140.639999, 143.610001, 146.760010, 150.089996, 153.600006, 157.290009, 161.160004, 165.209991, 149.279999, 152.970001, 156.839996, 160.889999, 165.120010, 169.529999, 174.119995, 178.889999, 157.919998, 162.330002, 166.919983, 171.690002, 176.639999, 181.769989, 187.079987, + 192.570007, 166.559998, 171.690002, 177.000000, 182.489990, 188.160004, 194.010010, 200.040009, 206.250000, 100.799995, 104.220001, 107.760002, 111.419998, 115.200005, 119.099998, 123.120003, 127.260010, 139.200012, 144.059998, 149.040009, 154.139999, 159.360001, 164.699997, 170.160004, 175.739990, 192.479996, 199.770020, 207.239990, + 214.889999, 
222.720001, 230.730011, 238.919998, 247.290009, 201.119995, 209.129990, 217.319992, 225.690002, 234.240005, 242.970001, 251.880005, 260.970001, 209.760010, 218.489990, 227.399994, 236.490005, 245.760010, 255.209991, 264.839996, 274.649994, 218.399994, 227.850006, 237.479996, 247.289993, 257.279999, 267.449982, 277.799988, + 288.330017, 227.040009, 237.209991, 247.559998, 258.089996, 268.800018, 279.690002, 290.760010, 302.010010, 235.679993, 246.570007, 257.639984, 268.889984, 280.320007, 291.929993, 303.720001, 315.690002, 244.320007, 255.929993, 267.720001, 279.690002, 291.839996, 304.169983, 316.679993, 329.369995, 252.959991, 265.290009, 277.799988, + 290.489990, 303.359985, 316.410004, 329.640015, 343.050018, 139.199997, 147.419998, 155.760010, 164.220001, 172.799988, 181.500000, 190.319992, 199.260010, 216.000000, 225.660004, 235.440002, 245.339996, 255.360016, 265.500000, 275.760010, 286.140015, 278.880005, 293.369995, 308.040009, 322.889984, 337.920013, 353.129974, 368.519989, + 384.090027, 287.520020, 302.730011, 318.119995, 333.690002, 349.440002, 365.369995, 381.479980, 397.770020, 296.160004, 312.089996, 328.199982, 344.489990, 360.960022, 377.609985, 394.440002, 411.449982, 304.799988, 321.450012, 338.280029, 355.289978, 372.480011, 389.850006, 407.399994, 425.130005, 313.440002, 330.809998, 348.359985, 366.089996, 384.000000, 402.090027, 420.359985, 438.809998, 322.079987, 340.169983, 358.440002, 376.889984, 395.520020, 414.329987, 433.320007, 452.489990, 330.720001, 349.530029, 368.520020, 387.690002, 407.039978, 426.570007, 446.279999, 466.170013, 339.360016, 358.890015, 378.599976, 398.490021, 418.559998, 438.809998, 459.239990, 479.849976, 177.600006, 190.619995, 203.759995, 217.020004, 230.399994, 243.899994, 257.519989, 271.260010, 292.799988, 307.260010, 321.839996, 336.539978, 351.360016, 366.299988, 381.359985, 396.540009, 365.279999, 386.970001, 408.839996, 430.889984, 453.120026, 475.529968, 498.119995, 520.890015, 373.920013, 396.329987, 
418.919983, 441.690002, 464.640015, 487.769958, 511.079987, 534.570007, 382.559998, 405.690002, 429.000000, 452.489990, 476.160004, 500.010010, 524.039978, 548.250000, 391.200012, 415.049988, 439.080017, 463.290009, 487.679993, 512.250000, 537.000000, 561.930054, 399.839996, 424.409973, 449.160034, 474.089966, 499.200012, 524.489990, 549.959961, 575.609985, 408.479980, 433.770020, 459.239990, 484.889954, 510.720032, 536.729980, 562.919983, 589.290039, 417.119995, 443.130005, 469.319977, 495.690002, 522.239990, 548.969971, 575.880005, 602.969971, 425.760010, 452.489990, 479.399994, 506.489990, 533.760010, 561.209961, 588.839966, 616.650024, 216.000000, 233.819992, 251.760010, 269.820007, 288.000000, 306.299988, 324.719971, 343.260010, 369.600006, 388.859985, 408.239990, 427.739990, 447.360016, 467.100006, 486.959961, 506.940002, 451.679993, 480.570007, 509.639984, 538.890015, 568.320007, 597.929993, 627.719971, 657.690002, 460.320007, 489.929993, 519.719971, 549.690002, 579.840027, 610.170044, 640.680054, 671.369995, 468.960022, 499.289978, 529.799988, 560.489990, 591.359985, 622.409973, 653.640015, 685.049988, 477.599976, 508.650024, 539.880005, 571.289978, 602.880005, 634.650024, 666.599976, 698.729980, 486.239990, 518.010010, 549.960022, 582.089966, 614.400024, 646.890015, 679.559937, 712.410034, 494.879974, 527.369995, 560.039978, 592.890015, 625.920044, 659.130005, 692.520020, 726.089966, 503.519989, 536.729980, 570.119995, 603.689941, 637.440063, 671.369995, 705.480042, 739.770020, 512.160034, 546.089966, 580.199951, 614.489990, 648.960022, 683.609985, 718.440002, 753.449951, 254.400009, 277.020020, 299.760010, 322.619995, 345.600006, 368.700012, 391.919983, 415.260010, 446.399994, 470.459961, 494.640015, 518.940002, 543.360046, 567.900024, 592.559998, 617.340027, 538.080017, 574.170044, 610.440002, 646.890015, 683.520020, 720.329956, 757.320007, 794.489990, 546.719971, 583.530029, 620.520020, 657.690002, 695.040039, 732.570007, 770.279968, 808.169983, 
555.359985, 592.889954, 630.599976, 668.489990, 706.559998, 744.809998, 783.239990, 821.849976, 564.000000, 602.250000, 640.679993, 679.289978, 718.080017, 757.050049, 796.199951, 835.530029, 572.640015, 611.609985, 650.760010, 690.089966, 729.600037, 769.289978, 809.160034, 849.210083, 581.279968, 620.970032, 660.839966, 700.889954, 741.119995, 781.529968, 822.119995, 862.890015, 589.919983, 630.330017, 670.919983, 711.690002, 752.640015, 793.770020, 835.079956, 876.570007, 598.559998, 639.690002, 681.000000, 722.490051, 764.160034, 806.010010, 848.039978, 890.250061, 292.799988, 320.220001, 347.760010, 375.419983, 403.200012, 431.100006, 459.119995, 487.260010, 523.199951, 552.059998, 581.040039, 610.139954, 639.360046, 668.699951, 698.159973, 727.739990, 624.479980, 667.770020, 711.239990, 754.890015, 798.719971, 842.729980, 886.919983, 931.290039, 633.119995, 677.130005, 721.319946, 765.690002, 810.239990, 854.969971, 899.880005, 944.969971, 641.760010, 686.489990, 731.400024, 776.489990, 821.760010, 867.209961, 912.839966, 958.650024, 650.400024, 695.849976, 741.479980, 787.290039, 833.279968, 879.449951, 925.799927, 972.330017, 659.040039, 705.210022, 751.559998, 798.089966, 844.800049, 891.690002, 938.760010, 986.010010, 667.679993, 714.569946, 761.640015, 808.890015, 856.320007, 903.929993, 951.719971, 999.690063, 676.320007, 723.929993, 771.719971, 819.690002, 867.839966, 916.169922, 964.679932, 1013.369995, 684.959961, 733.290039, 781.800049, 830.489990, 879.359985, 928.410034, 977.640015, 1027.050049, 331.199982, 363.419983, 395.760010, 428.220001, 460.799988, 493.500000, 526.320007, 559.260010, 600.000000, 633.660034, 667.440002, 701.339966, 735.359985, 769.500000, 803.759949, 838.140015, 710.880005, 761.369995, 812.039978, 862.889893, 913.919983, 965.130005, 1016.520020, 1068.090088, 719.520020, 770.729980, 822.119934, 873.689941, 925.440063, 977.369995, 1029.479980, 1081.770020, 728.160034, 780.090088, 832.199951, 884.489990, 936.960022, 989.610046, 
1042.439941, 1095.449951, 736.799927, 789.449951, 842.280029, 895.290039, 948.480042, 1001.849976, 1055.399902, 1109.129883, 745.439941, 798.810059, 852.359985, 906.089966, 960.000000, 1014.089966, 1068.359985, 1122.810059, 754.080017, 808.170044, 862.440002, 916.890015, 971.520020, 1026.330078, 1081.319946, 1136.489990, 762.720032, 817.530029, 872.520020, 927.689941, 983.040039, 1038.569946, 1094.280029, 1150.169922, 771.359985, 826.890015, 882.599976, 938.489990, 994.559998, 1050.810059, 1107.239990, 1163.849976, 369.599976, 406.619995, 443.760010, 481.020020, 518.400024, 555.900024, 593.520020, 631.260010, 113.279999, 136.839996, 160.480011, 184.199982, 208.000015, 231.880005, 255.839996, 279.880005, 31.359985, 66.699989, 102.160004, 137.740005, 173.440002, 209.260010, 245.199982, 281.260010, 31.359993, 67.179993, 103.120003, 139.179993, 175.360016, 211.660004, 248.079987, 284.619995, 31.359993, 67.659996, 104.080009, 140.619995, 177.280014, 214.060013, 250.959991, 287.980011, 31.359993, 68.139999, 105.039993, 142.059982, 179.200027, 216.459991, 253.839996, 291.339996, 31.360008, 68.619995, 106.000000, 143.499985, 181.119995, 218.860001, 256.719971, 294.700012, 31.360001, 69.099991, 106.959984, 144.939987, 183.040009, 221.260010, 259.600006, 298.059998, 31.360008, 69.579971, 107.920006, 146.379990, 184.960007, 223.660004, 262.479980, 301.419983, 31.360001, 70.059975, 108.880020, 147.819977, 186.880020, 226.059998, 265.359985, 304.779999, -83.840004, -58.040001, -32.159988, -6.200012, 19.840012, 45.959984, 72.159996, 98.440010}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + + sd::ops::depthwise_conv2d op; + auto results = op.evaluate({&input, &weights}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test1) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=4,oW=3; + int oC=iC*mC; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + NDArray expGradI('c', {bS, iH, iW, iC},{0.07 , 0.19 , 0.348, 0.652, 0.588, 0.956, 0.387, 0.687, 1.326, 2.022, 1.878, 2.67 , 1.071, 1.515, 2.982, 3.966, 3.534, 4.614, 1.606, 1.982, 3.932, 4.748, 4.428, 5.308, + 1.126, 1.63 , 3.228, 4.3 , 3.468, 4.604, 3.123, 3.999, 7.95 , 9.798, 8.502, 10.446, 3.807, 4.827, 9.606, 11.742,10.158, 12.39 , 4.198, 4.958, 9.884, 11.468,10.38 , 12.028}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {kH, kW, iC, mC},{19.08, 19.44,19.8 , 20.16,12.24, 12.48,12.72, 12.96,22.56, 23.04,23.52, 24. 
,14.4 , 14.72,15.04, 15.36,14.76, 15.12,15.48, 15.84, 9.36, 9.6 , 9.84, 10.08}, sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* gradI = results.at(0); + auto* gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test2) { + + int bS=2, iH=4,iW=3, iC=2,mC=2, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int oC=iC*mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3,4}); + auto gradO = NDArrayFactory::create('c', {bS, oH, oW, oC}); + + NDArray expGradI('c', {bS, iH, iW, iC},{0.005, 0.025,0.034, 0.106,0.061, 0.113,0.058, 0.162,0.292, 0.564,0.298, 0.466,0.234, 0.402,0.772, 1.172,0.602, 0.834,0.333, 0.449,0.882, 1.146,0.581, 0.729, + 0.053, 0.137,0.258, 0.458,0.237, 0.353,0.41 , 0.642,1.252, 1.78 ,0.906, 1.202,1.098, 1.394,2.756, 3.412,1.722, 2.082,0.893, 1.073,2.13 , 2.522,1.269, 1.481}, sd::DataType::FLOAT32); + NDArray expGradW('c', {kH, kW, iC, mC},{2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88,2.4 , 2.56,2.72, 2.88}, sd::DataType::FLOAT32); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, 
paddingMode, dataFormat}); + auto* gradI = results.at(0); + auto* gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test3) { + + auto in = NDArrayFactory::create('c', {4, 8, 64, 64}); + auto w = NDArrayFactory::create('c', {2, 2, 8, 2}); + auto b = NDArrayFactory::create('c', {1, 16}); + auto grad = NDArrayFactory::create('c', {4, 16, 64, 64}); + + auto gradI = in.like(); + auto gradW = w.like(); + auto gradB = b.like(); + + sd::ops::depthwise_conv2d_bp op; + auto status = op.execute({&in, &w, &b, &grad}, {&gradI, &gradW, &gradB}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0}); + ASSERT_EQ(Status::OK(), status); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test4) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + gradO.linspace(10, -0.1); + + + NDArray expGradI('c', {bS, iH, iW, iC},{10.880001, 13.239998, 15.520001, 17.719997, 19.840000, 21.880001, 23.839998, 25.720001, 31.360004, 34.420002, 37.360001, 40.180004, 42.880005, 45.460003, 47.919994, 50.260002, 31.360001, 33.939999, 36.400002, 38.739998, 40.959999, 43.059998, 45.040001, 46.900005, 31.359997, 33.459999, 35.439999, 37.300003, 39.040001, 40.660000, 42.160000, 43.539997, 31.360001, 
32.980000, 34.480000, 35.860001, 37.119999, 38.259998, 39.279999, 40.180000, 31.360001, 32.499996, 33.520000, 34.419998, 35.200001, 35.860001, 36.400002, 36.820000, 31.360001, 32.019997, 32.560001, 32.979996, 33.280003, 33.459999, 33.520000, 33.459999, 31.360001, 31.540001, 31.599998, 31.539999, 31.360001, 31.059999, 30.639999, 30.100000, 31.360001, 31.060001, 30.639999, 30.099998, 29.440002, 28.660000, 27.759998, 26.740000, 18.559999, 18.040001, 17.440001, 16.760000, 16.000000, 15.160000, 14.240001, 13.240000, 85.439995, 85.860001, 86.159996, 86.339996, 86.400002, 86.340012, 86.159996, 85.860008, 132.000000, 131.910004, 131.639999, 131.190002, 130.559998, 129.750000, 128.760010, 127.589996, 123.360001, 122.550003, 121.559998, 120.389999, 119.040009, 117.510002, 115.799988, 113.910004, 114.720001, 113.189995, 111.480003, 109.590004, 107.520004, 105.270004, 102.839996, 100.230011, 106.079994, 103.830002, 101.400009, 98.790009, 96.000008, + 93.030006, 89.879990, 86.549988, 97.439995, 94.469994, 91.319992, 87.990005, 84.479996, 80.789993, 76.919998, 72.870003, 88.800003, 85.110001, 81.239998, 77.190002, 72.960007, 68.550003, 63.959999, 59.190002, 80.160004, 75.750000, 71.160004, 66.389999, 61.440002, 56.309994, 51.000000, 45.510002, 71.519997, 66.389999, 61.079998, 55.590000, 49.919998, 44.070000, 38.040001, 31.830002, 31.680000, 27.780003, 23.760000, 19.619999, 15.360001, 10.980000, 6.480000, 1.859999, 47.040001, 42.660004, 38.160000, 33.540001, 28.799999, 23.939999, 18.960001, 13.860001, 45.599998, 38.310001, 30.840000, 23.190002, 15.360001, 7.349998, -0.840002, -9.210003, 36.959999, 28.950003, 20.759998, 12.390001, 3.839998, -4.889999, -13.799999, -22.890003, 28.320002, 19.589998, 10.680000, 1.590002, -7.680002, -17.129999, -26.759998, -36.570007, 19.680002, 10.230003, 0.599998, -9.210001, -19.199999, -29.370003, -39.720001, -50.250008, 11.039999, 0.869999, -9.480000, -20.010002, -30.719994, -41.610001, -52.679996, -63.930008, 2.400005, -8.489998, -19.560005, 
-30.809998, -42.239998, -53.849991, -65.639992, -77.610001, -6.239998, -17.849998, -29.639988, -41.609985, -53.760002, -66.090004, -78.599991, -91.290009, -14.879990, -27.209995, -39.720009, -52.410007, -65.279999, -78.330002, -91.559998, -104.969986, -45.119995, -53.820000, -62.639999, -71.580002, -80.640007, -89.819992, -99.119995, -108.540009, 8.639999, -0.540001, -9.839996, -19.259998, -28.799995, -38.459999, -48.240002, -58.140003, -40.799999, -55.289997, -69.960007, -84.810013, -99.840004, -115.050011, -130.440018, -146.010010, -49.439991, -64.650009, -80.040009, -95.610016, -111.360008, -127.290001, -143.399994, -159.690018, -58.080009, -74.009987, -90.119995, -106.409988, -122.880005, -139.530014, -156.360001, -173.369995, -66.720001, -83.369995, -100.199997, + -117.209999, -134.399994, -151.769989, -169.319992, -187.049988, -75.360008, -92.729996, -110.279991, -128.009979, -145.920013, -164.009995, -182.279984, -200.729996, -84.000000, -102.089996, -120.360016, -138.809967, -157.440002, -176.249969, -195.240005, -214.410019, -92.639999, -111.449997, -130.440018, -149.610016, -168.960007, -188.489990, -208.200012, -228.090012, -101.279976, -120.809982, -140.519989, -160.410004, -180.480011, -200.730011, -221.160034, -241.770020, -121.920006, -135.420013, -149.040009, -162.779999, -176.640015, -190.619995, -204.719986, -218.940002, -29.760002, -43.739998, -57.840000, -72.059998, -86.400009, -100.860001, -115.439995, -130.140015, -127.199997, -148.890015, -170.760010, -192.809998, -215.040024, -237.450012, -260.039978, -282.809998, -135.839996, -158.250000, -180.840012, -203.610046, -226.559982, -249.690002, -272.999969, -296.489990, -144.479980, -167.609985, -190.920013, -214.410019, -238.080032, -261.929993, -285.959991, -310.169983, -153.119995, -176.969986, -201.000031, -225.210022, -249.599976, -274.170013, -298.920013, -323.849976, -161.760040, -186.330017, -211.079987, -236.009995, -261.120026, -286.410034, -311.879974, -337.530029, -170.400009, 
-195.689987, -221.159973, -246.809998, -272.639954, -298.650024, -324.840057, -351.209991, -179.039963, -205.050018, -231.240021, -257.609985, -284.160004, -310.890015, -337.799988, -364.890015, -187.680023, -214.410004, -241.319977, -268.410004, -295.679993, -323.130005, -350.760010, -378.570038, -198.720016, -217.019989, -235.440002, -253.979980, -272.640045, -291.419983, -310.319977, -329.339996, -68.159981, -86.939987, -105.840012, -124.860001, -144.000000, -163.260010, -182.639984, -202.140015, -213.600021, -242.489990, -271.559937, -300.809998, -330.239990, -359.849976, -389.639984, + -419.610016, -222.240036, -251.849960, -281.640015, -311.609985, -341.760040, -372.089996, -402.600037, -433.290009, -230.880005, -261.210022, -291.719971, -322.410034, -353.280029, -384.329956, -415.559998, -446.970001, -239.519989, -270.570007, -301.800018, -333.209991, -364.800018, -396.570007, -428.520020, -460.650024, -248.160034, -279.929962, -311.880005, -344.010010, -376.320038, -408.809998, -441.479980, -474.330017, -256.799988, -289.289978, -321.960022, -354.809967, -387.839996, -421.050018, -454.440002, -488.009979, -265.440002, -298.650024, -332.040009, -365.609985, -399.360016, -433.290009, -467.399963, -501.689941, -274.080017, -308.009949, -342.119995, -376.409973, -410.880005, -445.530029, -480.359985, -515.369995, -275.520020, -298.619995, -321.839966, -345.179993, -368.640015, -392.220001, -415.919952, -439.740021, -106.560005, -130.140030, -153.840027, -177.659973, -201.599991, -225.660019, -249.840012, -274.140015, -300.000000, -336.090057, -372.360046, -408.809937, -445.440002, -482.250031, -519.240051, -556.410034, -308.640015, -345.450012, -382.440002, -419.609955, -456.959961, -494.489960, -532.200012, -570.089966, -317.280029, -354.809998, -392.520020, -430.410004, -468.480042, -506.729980, -545.159912, -583.770020, -325.920013, -364.169952, -402.600037, -441.210022, -480.000000, -518.970032, -558.119873, -597.449951, -334.559967, -373.529999, 
-412.679993, -452.009949, -491.519989, -531.209961, -571.080017, -611.129944, -343.200012, -382.889984, -422.760071, -462.809906, -503.039978, -543.449951, -584.039978, -624.809998, -351.839966, -392.250000, -432.839966, -473.609955, -514.560120, -555.689941, -596.999939, -638.489990, -360.480011, -401.610016, -442.920044, -484.409912, -526.080017, -567.929993, -609.959961, -652.169983, -352.320007, -380.220001, + -408.239990, -436.380005, -464.639984, -493.019989, -521.519958, -550.139954, -144.960022, -173.339996, -201.839996, -230.459976, -259.200043, -288.059998, -317.039978, -346.140015, -386.399963, -429.690002, -473.159912, -516.809937, -560.640076, -604.650024, -648.839966, -693.210022, -395.039978, -439.050018, -483.239929, -527.609985, -572.159973, -616.890015, -661.799988, -706.890015, -403.680023, -448.409973, -493.320007, -538.410034, -583.680054, -629.129944, -674.760010, -720.570068, -412.320007, -457.769897, -503.399963, -549.210083, -595.199951, -641.369995, -687.720093, -734.250000, -420.960052, -467.130035, -513.479980, -560.010010, -606.720093, -653.610046, -700.680054, -747.930115, -429.599976, -476.489990, -523.559998, -570.809937, -618.239990, -665.849976, -713.640015, -761.609985, -438.239990, -485.850037, -533.640015, -581.610046, -629.760010, -678.089966, -726.600037, -775.289917, -446.880035,-495.210052, -543.719971, -592.410034, -641.279968, -690.330017, -739.559937, -788.970093, -429.120026, -461.819946, -494.639984, -527.580017, -560.640015, -593.820007, -627.119995, -660.540039, -183.360016, -216.540009, -249.839996, -283.260040, -316.800018, -350.459961, -384.239990, -418.139984, -472.800049, -523.289917, -573.959961, -624.809998, -675.839966, -727.050049, -778.440063, -830.010010, -481.440002, -532.649963, -584.040100, -635.609985, -687.359924, -739.290039, -791.399963, -843.689941, -490.079987, -542.010010, -594.119995, -646.410034, -698.880005, -751.529968, -804.359985, -857.369995, -498.720032, -551.369995, -604.200012, 
-657.210022, -710.400024, -763.770081, -817.319946, -871.050049, -507.359955, -560.729919, -614.280029, -668.010010, -721.919983, -776.010010, -830.280029, -884.730042, -515.999939, -570.089966, -624.360046, -678.809937, -733.440002, + -788.250000, -843.239990, -898.410034, -524.639954, -579.449951, -634.440002, -689.609985, -744.960022, -800.489990, -856.200012, -912.090027, -533.280029, -588.810059, -644.520081, -700.409973, -756.480042, -812.730103, -869.159912, -925.769958, -505.920013, -543.420044, -581.040039, -618.780029, -656.640015, -694.620056, -732.719971, -770.940002, -447.359985, -471.559998, -495.840027, -520.200012, -544.640015, -569.159973, -593.760010, -618.440002, -815.359985, -852.140015, -889.040039, -926.059937, -963.200073, -1000.460022, -1037.839966, -1075.339966, -826.879944, -864.139954, -901.519958, -939.019958, -976.640076, -1014.379944, -1052.239990, -1090.219971, -838.400024, -876.140015, -913.999939, -951.979919, -990.080017, -1028.299927, -1066.640015, -1105.099976, -849.919983, -888.140015, -926.479980, -964.939941, -1003.520081, -1042.219971, -1081.040039, -1119.979980, -861.440063, -900.140015, -938.960022,-977.899963, -1016.960022, -1056.140015, -1095.440063, -1134.859985, -872.960022, -912.140015, -951.439941, -990.859985, -1030.400024, -1070.060059, -1109.839844, -1149.739990, -884.479980, -924.140015, -963.919922, -1003.819946, -1043.839966, -1083.979980, -1124.239990, -1164.619995, -896.000000, -936.140015, -976.399963, -1016.780029, -1057.280029, -1097.899902, -1138.640015, -1179.500122, -705.919983, -733.000000, -760.159912, -787.400024, -814.719971, -842.119995, -869.599976, -897.160034}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {kH, kW, iC, mC},{-104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, + -107702.734375, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104824.789062, 
+ -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -126744.000000, -127277.710938, -127813.187500, + -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -140944.000000, -141536.984375, -142131.984375, -142729.000000, -143328.000000, + -143929.015625, -144532.000000, -145137.000000, -126744.000000, -127277.710938, -127813.187500, -128350.484375, -128889.601562, -129430.515625, -129973.210938, -130517.703125, -104824.789062, -105305.117188, -105787.070312, -106270.640625, -106755.843750, -107242.640625, -107731.078125, -108221.117188, -116289.593750, -116823.296875, -117358.781250, -117896.109375, -118435.210938, -118976.109375, -119518.796875, -120063.296875, -104306.421875, -104786.734375, -105268.687500, -105752.250000, -106237.421875, -106724.242188, -107212.671875, -107702.734375}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {-2960., -2970., -2980., -2990., -3000., -3010., -3020., -3030.}, sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + NDArray* gradI = results.at(0); + NDArray* gradW = results.at(1); + NDArray* gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test5) { + + int bS=1, iH=10,iW=10, iC=8,mC=1, kH=3,kW=3, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oC=iC*mC; + int oH=10,oW=10; + int paddingMode = 1; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); 
+ NDArray weights('c', {kH, kW, iC, mC}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, sd::DataType::FLOAT32); + + input.linspace(-10, 0.1); + weights.linspace(-2, 0.1); + gradO.linspace(10, -0.1); + + + NDArray expGradI('c', {bS, iC, iH, iW}, {-12.639999, 3.920004, 3.920000, 3.920000, 3.920002, 3.920000, 3.920000, 3.919998, 3.919998, 16.319998, 52.680004, 111.000015, 109.919991, 108.840004, 107.760002, 106.680008, 105.600006, 104.519997, 103.440018, 87.960007, 47.880001, 100.200005, 99.119995, 98.040001, 96.959999, 95.879990, 94.799995, 93.720001, 92.639999, 78.360001, 43.079998, 89.399994, 88.320007, 87.240005, 86.159996, 85.079994, 84.000000, 82.919998, 81.840004, 68.759995, 38.279999, 78.600006, 77.519997, 76.440010, 75.360001, 74.279999, 73.200005, 72.120003, 71.040001, 59.160004, 33.480000, 67.799995, 66.720009, 65.639999, 64.559998, 63.480000, 62.399994, 61.320007, 60.240002, 49.559998, 28.680004, 57.000004, 55.919998, 54.839993, 53.759998, 52.680000, 51.600002, 50.519997, 49.440002, 39.959999, 23.880001, 46.200001, 45.120003, 44.039997, 42.959999, 41.880001, 40.799999, 39.719994, 38.639999, 30.360001, 19.079998, 35.400002, 34.320000, 33.239998, 32.159996, 31.080000, 29.999998, 28.919998, 27.840000, 20.759998, 14.079999, 24.080000, 22.639997, 21.200001, 19.759998, 18.320002, 16.880001, 15.440001, 14.000000, 9.759999, 3.140000, 3.560000, 3.500000, 3.440000, 3.380000, 3.320000, 3.260000, 3.200000, 3.140000, -0.220000, 4.050000, 2.010000, 0.840000, -0.330000, -1.499999, -2.670000, -3.840000, -5.010000, -6.179998, -9.150000, -1.350000, -9.690001, -10.859999, -12.029998, -13.200001, -14.370001, -15.539999, -16.710001, -17.879999, -19.349998, -6.750000, -21.389997, -22.560003, -23.730003, -24.900002, -26.069998, -27.239998, -28.410007, -29.580002, -29.550003, -12.150001, -33.089996, -34.260002, -35.430000, -36.600002, -37.770000, -38.939995, -40.110001, -41.280003, -39.749996, -17.550003, 
-44.790005, -45.959991, -47.129993, -48.300003, -49.470001, -50.640003, -51.809990, -52.979996, -49.950001, -22.949999, -56.490005, -57.660000, -58.829998, -60.000000, -61.170002, -62.340004, -63.510002, -64.680000, + -60.149994, -28.349998, -68.189987, -69.360001, -70.529999, -71.700005, -72.870010, -74.039993, -75.209999, -76.379990, -70.349998, -33.749996, -79.889999, -81.059990, -82.229988, -83.399994, -84.570007, -85.740005, -86.910004, -88.079994, -80.549995, -69.340004, -125.080002, -126.580002, -128.080002, -129.580002, -131.080002, -132.580002, -134.080002, -135.580002, -105.979996, 10.919998, -8.799997, -8.919998, -9.040003, -9.160004, -9.279999, -9.400002, -9.520002, -9.640003, -24.760000, -56.580009, -124.980003, -126.240005, -127.499992, -128.759995, -130.020020, -131.279999, -132.540009, -133.800003, -118.260002, -62.580009, -137.580002, -138.840012, -140.099991, -141.360001, -142.620010, -143.879974, -145.139999, -146.399994, -129.060013, -68.580002, -150.179993, -151.439987, -152.699997, -153.959991, -155.219986, -156.480011, -157.740005, -159.000000, -139.860001, -74.579994, -162.779999, -164.040024, -165.300003, -166.560028, -167.819977, -169.080002, -170.339996, -171.599991, -150.660004, -80.580002, -175.379990, -176.639999, -177.899994, -179.160019, -180.419998, -181.679993, -182.940002, -184.199997, -161.459991, -86.580002, -187.979996, -189.240005, -190.499985, -191.759995, -193.020020, -194.279999, -195.540024, -196.800018, -172.260010, -92.580002, -200.579987, -201.839981, -203.100006, -204.359970, -205.620010, -206.880005, -208.139999, -209.399994, -183.060013, -98.580002, -213.180023, -214.440002, -215.700012, -216.959991, -218.220001, -219.480011, -220.739975, -222.000000, -193.860001, -160.760010, -286.239990, -287.799988, -289.360016, -290.920013, -292.480011, -294.040009, -295.599976, -297.160004, -229.719986, 10.700003, -33.160004, -33.339996, -33.519993, -33.700001, + -33.879997, -34.059994, -34.239994, -34.419994, -57.299995, 
-129.209991, -269.969971, -271.319977, -272.670044, -274.019989, -275.369995, -276.720001, -278.070007, -279.420013, -239.369980, -135.809998, -283.470001, -284.820007, -286.169983, -287.520020, -288.869995, -290.220001, -291.570038, -292.919983, -250.770004, -142.410004, -296.969971, -298.320007, -299.669983, -301.020020, -302.369995, -303.719971, -305.070007, -306.419983, -262.169983, -149.009995, -310.470001, -311.820007, -313.170013, -314.519989, -315.869995, -317.220001, -318.570007, -319.919983, -273.570007, -155.610016, -323.969971, -325.320038, -326.669983, -328.020020, -329.369965, -330.719971, -332.070007, -333.419983, -284.970001, -162.209991, -337.469971, -338.820007, -340.169983, -341.519958, -342.869995, -344.220001, -345.570007, -346.920013, -296.369995, -168.809998, -350.970001, -352.320007, -353.669983, -355.019989, -356.369995, -357.719971, -359.070038, -360.419983, -307.769989, -175.410004, -364.469971, -365.820007, -367.169983, -368.520020, -369.869995, -371.219971, -372.570007, -373.919983, -319.169983, -260.179993, -459.399994, -461.019958, -462.639984, -464.260010, -465.880005, -467.500000, -469.119995, -470.739990, -361.459991, 2.480003, -69.520004, -69.760025, -70.000000, -70.239990, -70.479996, -70.720001, -70.960007, -71.200005, -97.839996, -213.840012, -432.960022, -434.400055, -435.840027, -437.279999, -438.720001, -440.160065, -441.599976, -443.040039, -372.480011, -221.040009, -447.360016, -448.800018, -450.239990, -451.679993, -453.119995, -454.559967, -456.000061, -457.440033, -384.480011, -228.239990, -461.759979, -463.200012, -464.639984, -466.079956, -467.520081, -468.960052, -470.399963, -471.839996, -396.479980, -235.440002, -476.159912, + -477.600006, -479.040039, -480.479980, -481.919952, -483.360046, -484.800079, -486.239990, -408.480042, -242.639999, -490.559967, -491.999969, -493.440063, -494.880035, -496.319946, -497.759979, -499.200012, -500.639984, -420.480011, -249.840012, -504.960052, -506.399963, -507.839996, 
-509.280029, -510.720001, -512.159973, -513.599976, -515.040039, -432.480011, -257.040009, -519.360046, -520.800049, -522.239990, -523.680054, -525.120056, -526.559998, -527.999939, -529.440002, -444.480011, -264.239990, -533.760010, -535.200012, -536.640015, -538.079956, -539.520020, -540.960022, -542.399963, -543.839966, -456.479980, -367.599976, -644.559998, -646.239929, -647.920044, -649.599976, -651.280029, -652.960022, -654.640076, -656.320007, -501.200043, -13.740002, -117.880005, -118.179993, -118.479996, -118.780014, -119.080002, -119.379990, -119.680008, -119.979996, -146.379990, -310.470001, -613.950012, -615.479980, -617.010071, -618.539978, -620.069946, -621.599976, -623.130005, -624.660034, -517.589966, -318.269958, -629.250000, -630.779968, -632.309937, -633.840027, -635.369995, -636.899902, -638.429993, -639.959961, -530.190063, -326.070038, -644.550049, -646.079956, -647.609985, -649.140015, -650.669922, -652.200012, -653.729980, -655.260010, -542.789978, -333.870026, -659.849976, -661.380005, -662.910034, -664.439941, -665.970093, -667.500000, -669.029968, -670.559937, -555.390015, -341.669983, -675.149902, -676.679993, -678.209961, -679.740051, -681.270020, -682.800049, -684.329956, -685.859985, -567.989990, -349.470001, -690.450012, -691.979980, -693.510010, -695.039978, -696.569946, -698.099976, -699.630005, -701.160034, -580.589966, -357.269958, -705.750000, -707.279968, -708.809937, -710.340027, -711.869995, -713.399902, -714.929993, -716.459961, -593.190002, -365.070038, -721.050049, -722.579956, -724.109985, -725.640015, -727.169922, -728.700012, + -730.229980, -731.760010, -605.789978, -483.019958, -841.719971, -843.460022, -845.200073, -846.939941, -848.680054, -850.419983, -852.159973, -853.899963, -648.940002, -37.960014, -178.240021, -178.599976, -178.959991, -179.320007, -179.679993, -180.039978, -180.399994, -180.759964, -202.919983, -419.099915, -812.939941, -814.559937, -816.179993, -817.800049, -819.419922, -821.040039, 
-822.660034, -824.279968, -674.699951, -427.500031, -829.140015, -830.759949, -832.380005, -833.999939, -835.619995, -837.240051, -838.859924, -840.479980, -687.899963, -435.899994, -845.339966, -846.959961, -848.579956, -850.200012, -851.819885, -853.439941, -855.059937, -856.679993, -701.100037, -444.299927, -861.540039, -863.160034, -864.779968, -866.399963, -868.020020, -869.640015, -871.259949, -872.880005, -714.299988, -452.700012, -877.740051, -879.359924, -880.979980, -882.599915, -884.219971, -885.839966, -887.459961, -889.079956, -727.500000, -461.099915, -893.939941, -895.559937, -897.179993, -898.800049, -900.419922, -902.040039, -903.660034, -905.279968, -740.700012, -469.499969, -910.140015, -911.759949, -913.380005, -914.999939, -916.620056, -918.239990, -919.860046, -921.479919, -753.899963, -477.899902, -926.339905, -927.959961, -929.579956, -931.200012, -932.819946, -934.439880, -936.059937, -937.679932, -767.100037, -606.439941, -1050.880005, -1052.680054, -1054.479980, -1056.280029, -1058.079956, -1059.880005, -1061.679932, -1063.479980, -804.679993, -70.180008, -250.600006, -251.019958, -251.440033, -251.860001, -252.280029, -252.700043, -253.120026, -253.540039, -267.459991, -539.730042, -1029.929932, -1031.640137, -1033.350098, -1035.060059, -1036.770020, -1038.479980, -1040.190063, -1041.900024, -843.809998, -548.729980, -1047.030029, -1048.740112, -1050.449829, -1052.160034, -1053.870117, -1055.580078, -1057.289917, -1059.000122, -857.609985, -557.729980, + -1064.130005, -1065.840088, -1067.550049, -1069.260010, -1070.969849, -1072.679932, -1074.390137, -1076.100098, -871.410034, -566.729980, -1081.229980, -1082.940063, -1084.650024, -1086.359985, -1088.069946, -1089.780029, -1091.489990, -1093.199951, -885.210022, -575.729980, -1098.329956, -1100.040039, -1101.750122, -1103.460205, -1105.170166, -1106.879883, -1108.589966, -1110.300049, -899.010071, -584.730042, -1115.429932, -1117.140137, -1118.850098, -1120.560059, -1122.270020, 
-1123.979980, -1125.689941, -1127.400024, -912.810059, -593.730042, -1132.530029, -1134.240234, -1135.949951, -1137.659912, -1139.370117, -1141.079956, -1142.790039, -1144.500122, -926.610046, -602.730042, -1149.629883, -1151.339966, -1153.050049, -1154.760132, -1156.469971, -1158.179810, -1159.890137, -1161.600098, -940.410034, -737.859985, -1272.040039, -1273.899902, -1275.760010, -1277.619995, -1279.479980, -1281.340088, -1283.200195, -1285.060059, -968.420044}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {kH, kW, iC, mC}, {-2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000, + -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2594.701416, -2513.699951, + -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -3043.501465, -2953.500244, -20863.500000, -56773.492188, + -110683.515625, -182593.515625, -272503.531250, -380413.562500, -3383.499756, -3283.500000, -23183.501953, -63083.500000, -122983.500000, -202883.515625, + -302783.531250, -422683.468750, -3043.501465, -2953.500244, -20863.500000, -56773.492188, -110683.515625, -182593.515625, -272503.531250, -380413.562500, + -2594.701416, -2513.699951, -18632.699219, -50951.695312, -99470.695312, -164189.703125, -245108.687500, -342227.750000, -2880.149902, -2790.150146, -20700.152344, -56610.148438, -110520.156250, -182430.156250, -272340.156250, -380250.125000, -2586.600586, -2505.600098, -18624.595703, -50943.605469, -99462.601562, -164181.609375, -245100.609375, -342219.625000}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {oC}, {505., -495., -1495., -2495., -3495., -4494.999512, -5495., -6495.}, sd::DataType::FLOAT32); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + NDArray* gradI = results.at(0); + NDArray* 
gradW = results.at(1); + NDArray* gradB = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + + ASSERT_TRUE(expGradB.isSameShape(gradB)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test6) { + + int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int oC=iC*mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto bias = NDArrayFactory::create('c', {oC}, {3,4}); + auto gradO = NDArrayFactory::create('c', {bS, oC, oH, oW}); + + auto expGradI = NDArrayFactory::create('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01, + 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136, + 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}); + + auto expGradW = NDArrayFactory::create('c', {kH, kW, iC, mC}, {1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68, 1.04, 1.68}); + + input = 2.; + weights.linspace(0.1, 0.1); + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto* gradI = results.at(0); + auto* gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests2, depthwise_conv2d_bp_test7) { + + int bS=2, iH=4,iW=3, iC=2,mC=1, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int oC=iC*mC; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + int wFormat = 1; // 0-[kH, kW, iC, mC], 1-[mC, iC, kH, kW], 2-[mC, kH, kW, iC] + + NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32); + NDArray weights('c', {mC, iC, kH, kW}, {0.10, 0.30, 0.50, 0.70, 0.90, 1.10, 0.20, 0.40, 0.60, 0.80, 1., 1.2}, sd::DataType::FLOAT32); + NDArray bias('c', {oC}, {3,4}, sd::DataType::FLOAT32); + NDArray gradO('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32); + + + NDArray expGradI('c', {bS, iC, iH, iW},{0.001, 0.005, 0.006, 0.008, 0.03, 0.026, 0.024, 0.07, 0.05, 0.027, 0.069, 0.044, 0.01, + 0.032, 0.024, 0.044, 0.12, 0.08, 0.092, 0.224, 0.136, 0.07, 0.164, 0.096, 0.009, 0.037, 0.03, 0.056, 0.158, 0.106, 0.136, + 0.326, 0.194, 0.099, 0.229, 0.132, 0.026, 0.08, 0.056, 0.108, 0.28, 0.176, 0.22, 0.512, 0.296, 0.15, 0.34, 0.192}, sd::DataType::FLOAT32); + + NDArray expGradW('c', {mC, iC, kH, kW}, {1.04, 1.04, 1.04, 1.04, 1.04, 1.04, 1.68, 1.68, 1.68, 1.68, 1.68, 1.68}, sd::DataType::FLOAT32); + + input = 2.; + gradO.linspace(0.01, 0.01); + + sd::ops::depthwise_conv2d_bp op; + auto results = op.evaluate({&input, &weights, &bias, &gradO}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat, wFormat}); + auto* gradI = results.at(0); + auto* gradW = results.at(1); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expGradI.isSameShape(gradI)); + ASSERT_TRUE(expGradI.equalsTo(gradI)); + + ASSERT_TRUE(expGradW.isSameShape(gradW)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); +} + +#endif //LIBND4J_CONVOLUTIONTESTS2_H \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CuDnnTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CuDnnTests.cu 
new file mode 100644 index 000000000..a7d2f7838 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CuDnnTests.cu @@ -0,0 +1,150 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + +#ifdef HAVE_CUDNN + +#include + +#endif + +using namespace sd; + +class CuDnnTests : public testing::Test { +public: + +}; + +static void printer(std::initializer_list helpers) { + + for (auto v:helpers) { + nd4j_printf("Initialized [%s]\n", v->name().c_str()); + } +} + + +TEST_F(CuDnnTests, helpers_includer) { + // we need this block, to make sure all helpers are still available within binary, and not optimized out by linker +#ifdef HAVE_CUDNN + sd::ops::platforms::PLATFORM_conv2d_ENGINE_CUDA conv2d; + sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CUDA conv2d_bp; + sd::ops::platforms::PLATFORM_conv3dnew_ENGINE_CUDA conv3dnew; + sd::ops::platforms::PLATFORM_conv3dnew_bp_ENGINE_CUDA conv3dnew_bp; + sd::ops::platforms::PLATFORM_depthwise_conv2d_ENGINE_CUDA depthwise_conv2d; + sd::ops::platforms::PLATFORM_depthwise_conv2d_bp_ENGINE_CUDA 
depthwise_conv2d_bp; + sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CUDA batchnorm; + sd::ops::platforms::PLATFORM_batchnorm_bp_ENGINE_CUDA batchnorm_bp; + sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CUDA avgpool2d; + sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CUDA avgpool2d_bp; + sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CUDA maxpool2d; + sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CUDA maxpool2d_bp; + sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CUDA avgpool3dnew; + sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CUDA avgpool3dnew_bp; + sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CUDA maxpool3dnew; + sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CUDA maxpool3dnew_bp; + + + + printer({&conv2d}); + printer({&conv2d_bp}); + printer({&conv3dnew}); + printer({&conv3dnew_bp}); + printer({&depthwise_conv2d}); + printer({&depthwise_conv2d_bp}); + printer({&batchnorm}); + printer({&batchnorm_bp}); + printer({&avgpool2d}); + printer({&avgpool2d_bp}); + printer({&maxpool2d}); + printer({&maxpool2d_bp}); + printer({&avgpool3dnew}); + printer({&avgpool3dnew_bp}); + printer({&maxpool3dnew}); + printer({&maxpool3dnew_bp}); +#endif +} + + +TEST_F(CuDnnTests, mixed_helpers_test_1) { +#if defined(HAVE_CUDNN) && defined (HAVE_MKLDNN) + nd4j_printf("Mixed platforms test\n", ""); + + + int bS=2, iH=4,iW=3, iC=4,oC=3, kH=3,kW=2, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 0; // 1-SAME, 0-VALID; + int dataFormat = 0; // 1-NHWC, 0-NCHW + + auto input = NDArrayFactory::create('c', {bS, iC, iH, iW}); + auto weights = NDArrayFactory::create('c', {oC, iC, kH, kW}); + auto bias = NDArrayFactory::create('c', {oC}, {1,2,3}); + + auto expOutput = NDArrayFactory::create('c', {bS, oC, oH, oW}, {61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f, 61.f, 61.f, 61.f, 61.f, 177.2f, 177.2f, 177.2f, 177.2f, 293.4f, 293.4f, 293.4f, 293.4f}); + auto zCUDA = expOutput.like(); + auto zMKL = expOutput.like(); + 
+ input = 2.; + weights.linspace(0.1, 0.1); + weights.permutei({2,3,1,0}); + + input.syncToHost(); + weights.syncToHost(); + bias.syncToHost(); + + sd::ops::conv2d op; + + // cuDNN part + Context cuda(1); + cuda.setTargetEngine(samediff::Engine::ENGINE_CUDA); + cuda.setInputArray(0, &input); + cuda.setInputArray(1, &weights); + cuda.setInputArray(2, &bias); + cuda.setOutputArray(0, &zCUDA); + cuda.setIArguments({kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto statusCUDA = op.execute(&cuda); + + ASSERT_EQ(Status::OK(), statusCUDA); + ASSERT_EQ(expOutput, zCUDA); + + // MKL-DNN part + Context mkl(1); + mkl.setTargetEngine(samediff::Engine::ENGINE_CPU); + mkl.setInputArray(0, &input); + mkl.setInputArray(1, &weights); + mkl.setInputArray(2, &bias); + mkl.setOutputArray(0, &zMKL); + mkl.setIArguments({kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat}); + auto statusMKL = op.execute(&mkl); + + zMKL.tickWriteHost(); + + ASSERT_EQ(Status::OK(), statusMKL); + ASSERT_EQ(expOutput, zMKL); +#endif +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaBasicsTests1.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaBasicsTests1.cu new file mode 100644 index 000000000..3a85f6eef --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaBasicsTests1.cu @@ -0,0 +1,2926 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author raver119@gmail.com + // + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class CudaBasicsTests1 : public testing::Test { +public: + +}; + + +////////////////////////////////////////////////////////////////////////// +static cudaError_t allocateDeviceMem(LaunchContext& lc, std::vector& devicePtrs, const std::vector>& hostData) { + + if(devicePtrs.size() != hostData.size()) + throw std::invalid_argument("prepareDataForCuda: two input sts::vectors should same sizes !"); + + cudaError_t cudaResult; + + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); if(cudaResult != 0) return cudaResult; + int* allocationPointer; + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); if(cudaResult != 0) return cudaResult; + + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + cudaStream_t stream = *lc.getCudaStream(); + + for(int i = 0; i < devicePtrs.size(); ++i) { + + cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), hostData[i].second); if(cudaResult != 0) return cudaResult; + cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice, stream); + } + return cudaResult; +} + +////////////////////////////////////////////////////////////////////////// 
+TEST_F(CudaBasicsTests1, TestPairwise_1) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', { 5 }, {0,0,0,0,0}); + + auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); + + // making raw buffers + Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + ASSERT_EQ(0, res); + res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + ASSERT_EQ(0, res); + res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + ASSERT_EQ(0, res); + + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); + cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); + + x.dataBuffer()->allocatePrimary(); + x.syncToHost(); + + cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream); + cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream); + res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + LaunchContext lc(stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, nullptr, x.shapeInfo(), devBufferPtrX, reinterpret_cast(devShapePtrX), nullptr, x.shapeInfo(), devBufferPtrX, reinterpret_cast(devShapePtrX), nullptr, z.shapeInfo(), devBufferPtrZ, reinterpret_cast(devShapePtrX), nullptr); + res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + z.dataBuffer()->allocatePrimary(); + + cudaMemcpyAsync(z.buffer(), devBufferPtrZ, z.lengthOf() * x.sizeOfT(), cudaMemcpyDeviceToHost, *stream); + res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + cudaFree(devBufferPtrX); + cudaFree(devBufferPtrZ); + 
cudaFree(devShapePtrX); + + // needed due to memcpy + z.tickWriteHost(); + + for (int e = 0; e < z.lengthOf(); e++) { + //nd4j_printf("step %i\n", e); + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } +} + + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execIndexReduceScalar_1) { + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, sd::DataType::BFLOAT16); + NDArray x3('c', {2,2}, {0, -1, 0, 1}, sd::DataType::BOOL); + + NDArray scalar('c', {}, std::vector{0}, sd::DataType::INT64); + + NDArray exp1('c', {}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp3('c', {}, std::vector{1}, sd::DataType::INT64); + + void *dX1, *dX2, *dX3, *dZ; + Nd4jLong *dX1ShapeInfo, *dX2ShapeInfo, *dX3ShapeInfo, *dZShapeInfo; + + cudaError_t cudaResult; + + cudaResult = cudaMalloc(reinterpret_cast(&dX1), x1.lengthOf() * x1.sizeOfT()); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX2), x2.lengthOf() * x2.sizeOfT()); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX3), x3.lengthOf() * x3.sizeOfT()); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ), scalar.lengthOf() * scalar.sizeOfT()); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX1ShapeInfo), shape::shapeInfoByteLength(x1.shapeInfo())); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX2ShapeInfo), shape::shapeInfoByteLength(x2.shapeInfo())); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX3ShapeInfo), shape::shapeInfoByteLength(x3.shapeInfo())); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZShapeInfo), shape::shapeInfoByteLength(scalar.shapeInfo())); ASSERT_EQ(0, cudaResult); + + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + + x1.syncToHost(); + 
x2.syncToHost(); + x3.syncToHost(); + scalar.syncToHost(); + + cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX1ShapeInfo, x1.shapeInfo(), shape::shapeInfoByteLength(x1.shapeInfo()), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX2ShapeInfo, x2.shapeInfo(), shape::shapeInfoByteLength(x2.shapeInfo()), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3ShapeInfo, x3.shapeInfo(), shape::shapeInfoByteLength(x3.shapeInfo()), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dZShapeInfo, scalar.shapeInfo(), shape::shapeInfoByteLength(scalar.shapeInfo()), cudaMemcpyHostToDevice, stream); + + void* reductionPointer = nullptr; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); + ASSERT_EQ(0, cudaResult); + cudaResult = cudaMemset(reductionPointer, 0, 1024 * 1024); + ASSERT_EQ(0, cudaResult); + + LaunchContext lc(&stream, LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getScalarPointer(), LaunchContext::defaultContext()->getAllocationPointer()); + + /***************************************/ + + NativeOpExecutioner::execIndexReduceScalar(&lc, + sd::indexreduce::IndexAbsoluteMax, + x1.buffer(), x1.shapeInfo(), + dX1, dX1ShapeInfo, + nullptr, + scalar.buffer(), scalar.shapeInfo(), + dZ, dZShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + scalar.tickWriteHost(); + + ASSERT_NEAR(exp1.e(0), scalar.e(0), 1e-5); + + /***************************************/ + + NativeOpExecutioner::execIndexReduceScalar(&lc, + 
sd::indexreduce::IndexAbsoluteMax, + nullptr, x2.shapeInfo(), + dX2, dX2ShapeInfo, + nullptr, + nullptr, scalar.shapeInfo(), + dZ, dZShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp2.e(0), scalar.e(0), 1e-5); + + // ************************************* + + NativeOpExecutioner::execIndexReduceScalar(&lc, + sd::indexreduce::IndexAbsoluteMax, + nullptr, x3.shapeInfo(), + dX3, dX3ShapeInfo, + nullptr, + nullptr, scalar.shapeInfo(), + dZ, dZShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar.buffer(), dZ, scalar.lengthOf() * scalar.sizeOfT(), cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp3.e(0), scalar.e(0), 1e-5); + + /***************************************/ + + cudaFree(dX1); cudaFree(dX2); cudaFree(dX3); cudaFree(dZ); + cudaFree(dX1ShapeInfo); cudaFree(dX2ShapeInfo); cudaFree(dX3ShapeInfo); cudaFree(dZShapeInfo); + + /***************************************/ + + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); + +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3Scalar_1) { + + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); + NDArray x2('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); + NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); + + NDArray exp1('c', {}, std::vector{-30.f}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15.}, sd::DataType::DOUBLE); + + NDArray scalar1('c', {}, std::vector{100.f}, sd::DataType::FLOAT32); + 
NDArray scalar2('c', {}, std::vector{100.}, sd::DataType::DOUBLE); + + void *dX1, *dX2, *dX3, *dX4, *dZ1, *dZ2; + Nd4jLong *dX1ShapeInfo, *dX3ShapeInfo, *dZ1ShapeInfo, *dZ2ShapeInfo; + + cudaError_t cudaResult; + + cudaResult = cudaMalloc(reinterpret_cast(&dX1), x1.lengthOf() * x1.sizeOfT()); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX2), x2.lengthOf() * x2.sizeOfT()); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX3), x3.lengthOf() * x3.sizeOfT()); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX4), x4.lengthOf() * x4.sizeOfT()); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ1), scalar1.lengthOf() * scalar1.sizeOfT()); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ2), scalar2.lengthOf() * scalar2.sizeOfT()); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX1ShapeInfo), shape::shapeInfoByteLength(x1.shapeInfo())); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dX3ShapeInfo), shape::shapeInfoByteLength(x3.shapeInfo())); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ1ShapeInfo), shape::shapeInfoByteLength(scalar1.shapeInfo())); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&dZ2ShapeInfo), shape::shapeInfoByteLength(scalar2.shapeInfo())); ASSERT_EQ(0, cudaResult); + + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); + ASSERT_EQ(0, cudaResult); + + x1.syncToHost(); + x2.syncToHost(); + x3.syncToHost(); + x4.syncToHost(); + scalar1.syncToHost(); + scalar2.syncToHost(); + + cudaMemcpyAsync(dX1, x1.buffer(), x1.lengthOf() * x1.sizeOfT(), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX2, x2.buffer(), x2.lengthOf() * x2.sizeOfT(), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3, x3.buffer(), x3.lengthOf() * x3.sizeOfT(), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX4, x4.buffer(), x4.lengthOf() * x4.sizeOfT(), 
cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX1ShapeInfo, x1.shapeInfo(), shape::shapeInfoByteLength(x1.shapeInfo()), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dX3ShapeInfo, x3.shapeInfo(), shape::shapeInfoByteLength(x3.shapeInfo()), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dZ1ShapeInfo, scalar1.shapeInfo(), shape::shapeInfoByteLength(scalar1.shapeInfo()), cudaMemcpyHostToDevice, stream); + cudaMemcpyAsync(dZ2ShapeInfo, scalar2.shapeInfo(), shape::shapeInfoByteLength(scalar2.shapeInfo()), cudaMemcpyHostToDevice, stream); + + /***************************************/ + + void* reductionPointer = nullptr; + int* allocationPointer = nullptr; + + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + + LaunchContext lc(&stream, reductionPointer, nullptr, allocationPointer); + + /***************************************/ + + NativeOpExecutioner::execReduce3Scalar(&lc, sd::reduce3::Dot,nullptr, x1.shapeInfo(),dX1, dX1ShapeInfo, nullptr, nullptr, x2.shapeInfo(),dX2, dX1ShapeInfo,nullptr, scalar1.shapeInfo(),dZ1, dZ1ShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + scalar1.tickWriteHost(); + scalar2.tickWriteHost(); + + cudaMemcpyAsync(scalar1.buffer(), dZ1, scalar1.lengthOf() * scalar1.sizeOfT(), cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp1.e(0), scalar1.e(0), 1e-5); + + /***************************************/ + + NativeOpExecutioner::execReduce3Scalar(&lc, sd::reduce3::Dot,nullptr, x3.shapeInfo(),dX3, dX3ShapeInfo, nullptr, nullptr, x4.shapeInfo(),dX4, dX3ShapeInfo,nullptr, scalar2.shapeInfo(),dZ2, dZ2ShapeInfo); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + cudaMemcpyAsync(scalar2.buffer(), dZ2, scalar2.lengthOf() * scalar2.sizeOfT(), 
cudaMemcpyDeviceToHost, stream); + + cudaResult = cudaStreamSynchronize(stream); + ASSERT_EQ(0, cudaResult); + + ASSERT_NEAR(exp2.e(0), scalar2.e(0), 1e-5); + + /***************************************/ + + cudaFree(dX1); cudaFree(dX2); cudaFree(dX3); cudaFree(dX4); cudaFree(dZ1); cudaFree(dZ2); + cudaFree(dX1ShapeInfo); cudaFree(dX3ShapeInfo); cudaFree(dZ1ShapeInfo); cudaFree(dZ2ShapeInfo); + + /***************************************/ + + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); +} + + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3_1) { + + NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); + NDArray y('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); + + NDArray exp('c', {}, std::vector{-30.f}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100.f}, sd::DataType::FLOAT32); + + std::vector dimensions = {0, 1}; + + x.syncToHost(); + y.syncToHost(); + z.syncToHost(); + + + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + std::vector devicePtrs(hostData.size(), nullptr); + + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + nullptr, nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + 
for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3_2) { + + NDArray x('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); + NDArray y('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); + + NDArray exp('c', {}, std::vector{15.}, sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100.}, sd::DataType::DOUBLE); + + std::vector dimensions = {0, 1}; + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + nullptr, nullptr, nullptr, nullptr); + + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) 
cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3_3) { + + NDArray x('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); + NDArray y('c', {2,3}, {-6,-5,-4,-3,-2,-1}, sd::DataType::INT32); + + NDArray exp('c', {3}, {-18,-20,-18}, sd::DataType::FLOAT32); + NDArray z('c', {3}, {100,100,100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, + 
nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3_4) { + + NDArray x('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); + NDArray y('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); + + NDArray exp('c', {2}, {9,22.5}, sd::DataType::DOUBLE); + NDArray z('c', {2}, {100,100}, sd::DataType::DOUBLE); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back(yTad.tadOnlyShapeInfo, 
shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3_5) { + + NDArray x('c', {2,2,3}, {1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::FLOAT32); + NDArray y('c', {2,2,3}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32); + + NDArray exp('c', {2,3}, {7.5, 10.5, 13.5, 25.5, 28.5, 31.5}, sd::DataType::FLOAT32); + NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD 
xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3(&lc, sd::reduce3::Dot, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free 
allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3All_1) { + + NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); + NDArray y('c', {2,3}, {-1,1,-1,1,-1,1}, sd::DataType::INT32); + + NDArray exp('c', {2,3}, {2,-2,2,2,-2,2}, sd::DataType::FLOAT32); + NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4 -- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda 
kernel which calculates result + NativeOpExecutioner::execReduce3All(&lc, sd::reduce3::Dot, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3All_2) { + + NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); + NDArray y('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); + + NDArray exp('c', {2,3}, {6,6,6,9,9,9}, sd::DataType::DOUBLE); + NDArray z('c', {2,3}, {100,100,100,100,100,100,},sd::DataType::DOUBLE); + + std::vector dimensions = {0}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // evaluate yTad data + shape::TAD yTad; + yTad.init(y.shapeInfo(), dimensions.data(), dimensions.size()); + yTad.createTadOnlyShapeInfo(); + yTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * 
sizeof(Nd4jLong)); // 2 -- xTadOffsets + hostData.emplace_back(yTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(yTad.tadOnlyShapeInfo));// 3 -- yTadShapeInfo + hostData.emplace_back(yTad.tadOffsets, yTad.numTads * sizeof(Nd4jLong)); // 4-- yTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3All(&lc, sd::reduce3::Dot, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + (Nd4jLong*)devicePtrs[3], (Nd4jLong*)devicePtrs[4]); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execIndexReduce_1) { + + NDArray x('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::DOUBLE); + x.linspace(-2.); x.syncToDevice(); + NDArray exp('c', {2}, {2, 2}, sd::DataType::INT64); + NDArray z('c', {2}, {100,100}, sd::DataType::INT64); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), 
dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execIndexReduce(&lc, sd::indexreduce::IndexMax, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); + if (cudaResult != 0) + throw sd::cuda_exception::build("execIndexReduce failed", cudaResult); + + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execIndexReduce_2) { + + NDArray x('c', {2,3,4,5}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, + 
100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, + 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, + 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, + 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, + 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::FLOAT32); + x.linspace(-2.f); x.syncToDevice(); + NDArray exp('c', {2,5}, {11,11,11,11,11,11,11,11,11,11}, sd::DataType::INT64); + NDArray z('c', {2,5}, {100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT64); + + std::vector dimensions = {1,2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execIndexReduce(&lc, sd::indexreduce::IndexMax, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], 
(Nd4jLong*)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execIndexReduce_3) { + + NDArray x('c', {2,3,4,5}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, + 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, + 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, + 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, + 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100, + 100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); + x.linspace(-2.); x.syncToDevice(); + NDArray exp('c', {3}, {39, 39, 39}, sd::DataType::INT64); + NDArray z('c', {3}, {100,100,100}, sd::DataType::INT64); + + std::vector dimensions = {0,2,3}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + 
cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execIndexReduce(&lc, sd::indexreduce::IndexMax, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execScalar_1) { + + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x('c', {2,3}, {0,1,2,3,4,5}, sd::DataType::INT64); + NDArray exp('c',{2,3}, {0,0,1,1,2,2}, sd::DataType::INT64); + NDArray scalar('c',{}, std::vector{2.f}, sd::DataType::FLOAT32); + NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::INT64); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalar(&lc, sd::scalar::Divide, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, scalar.shapeInfo(), 
scalar.specialBuffer(), scalar.specialShapeInfo(), + nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execScalar_2) { + + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x('c', {2,3}, {-1,-2,-3,-4,-5,-6}, sd::DataType::INT64); + NDArray exp('c',{2,3}, {10,10,10,10,10,10}, sd::DataType::FLOAT32); + NDArray scalar('c',{}, std::vector{10.f}, sd::DataType::FLOAT32); + NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalar(&lc, sd::scalar::CopyPws, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), + nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execScalar_3) { + + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x('c', {2,3,2}, {0,1,2,3,4,5,6,7,8,9,10,11}, sd::DataType::INT64); + NDArray scalars('c',{2,2}, {1,2,3,4}, sd::DataType::FLOAT32); + NDArray 
exp('c', {2,3,2}, {0,0,2,1,4,2, 2,1,2,2,3,2}, sd::DataType::INT64); + NDArray z('c', {2,3,2}, {100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT64); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalar(&lc, sd::scalar::Divide, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, scalars.shapeInfo(), scalars.specialBuffer(), scalars.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = 
cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execScalarBool_1) { + + NDArray x('c', {2,3}, {-1,-2,0,1,2,3}, sd::DataType::BFLOAT16); + NDArray scalar('c',{}, std::vector{0}, sd::DataType::BFLOAT16); + NDArray exp('c',{2,3}, {0,0,0,1,1,1}, sd::DataType::BOOL); + NDArray z('c', {2,3}, {100,100,100,100,100,100,}, sd::DataType::BOOL); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + // call cuda kernel which calculates result + NativeOpExecutioner::execScalarBool(&lc, sd::scalar::GreaterThan, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, scalar.shapeInfo(), scalar.specialBuffer(), scalar.specialShapeInfo(), + nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execScalarBool_2) { + + NDArray x('c', {2,3}, {0,1,2,3,4,5}, sd::DataType::FLOAT32); + NDArray scalars('c',{2}, {-1,4}, sd::DataType::FLOAT32); + NDArray exp('c', {2,3}, {1,1,1,0,0,1}, sd::DataType::BOOL); + NDArray z('c', {2,3}, {100,100,100,100,100,100}, sd::DataType::BOOL); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + 
hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execScalarBool(&lc, sd::scalar::GreaterThan, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, scalars.shapeInfo(), scalars.specialBuffer(), scalars.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execBroadcast_1) { + + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); + NDArray y('c', {3}, {10, 20, 30}, sd::DataType::INT64); + 
NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); + NDArray exp('c', {2,3,4}, {10, 11, 12, 13,24, 25, 26, 27,38, 39, 40, 41,22, 23, 24, 25,36, 37, 38, 39,50, 51, 52, 53}, sd::DataType::INT32); + x.linspace(0); x.syncToDevice(); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast(&lc, sd::broadcast::Add, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device 
memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execBroadcast_2) { + + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); + NDArray y('c', {2,4}, {10,20,30,40,50,60,70,80}, sd::DataType::FLOAT32); + NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::FLOAT32); + NDArray exp('c', {2,3,4}, {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::FLOAT32); + x.linspace(0); x.syncToDevice(); + + std::vector dimensions = {0,2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which 
calculates result + NativeOpExecutioner::execBroadcast(&lc, sd::broadcast::Add, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execBroadcastBool_1) { + + NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); + NDArray y('c', {3}, {2, 12, 22}, sd::DataType::INT32); + NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,}, sd::DataType::BOOL); + NDArray exp('c', {2,3,4}, {0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0}, sd::DataType::BOOL); + x.linspace(1); x.syncToDevice(); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, 
xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcastBool(&lc, sd::broadcast::EqualTo, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execBroadcastBool_2) { + + NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100},sd::DataType::FLOAT32); + NDArray y('c', {2,4}, {1,10,10,15,20,20,20,24}, sd::DataType::FLOAT32); + NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::BOOL); + NDArray exp('c', {2,3,4}, {1, 0, 0, 0,0, 0, 0, 0,0, 1, 0, 0,0, 0, 0, 0,0, 0, 0, 0,0, 0, 0, 1}, sd::DataType::BOOL); + x.linspace(1); x.syncToDevice(); + + std::vector dimensions 
= {0,2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcastBool(&lc, sd::broadcast::EqualTo, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execPairwiseTransform_1) { + + if 
(!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x('c', {2,2,2}, {1,5,3,7,2,6,4,8}, sd::DataType::INT32); + NDArray y('c', {4,2}, {0.1,0.2,0.3,0.4,1.5,0.6,0.7,1.8}, sd::DataType::DOUBLE); + NDArray z('c', {8}, {100,100,100,100,100,100,100,100}, sd::DataType::INT32); + NDArray exp('c', {8}, {0,1,2,3,3,5,6,6}, sd::DataType::INT32); + x.permutei({2,1,0}); // -> {1,2,3,4,5,6,7,8} + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execPairwiseTransform(&lc, sd::pairwise::Subtract, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execPairwiseBoolTransform_1) { + + NDArray x('c', {2,2,2}, {1,5,3,7,2,6,4,8}, sd::DataType::INT64); + NDArray y('c', {4,2}, {0,2,0,4,0,6,0,8}, sd::DataType::INT64); + NDArray z('c', {8}, {100,100,100,100,100,100,100,100}, sd::DataType::BOOL); + NDArray exp('c', {8}, {0,1,0,1,0,1,0,1}, sd::DataType::BOOL); + x.permutei({2,1,0}); // -> {1,2,3,4,5,6,7,8} + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + 
NativeOpExecutioner::execPairwiseBoolTransform(&lc, sd::pairwise::EqualTo, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execTransformFloat_1) { + + NDArray x('c', {2,2}, {0, 6.25, 2.25, 12.25}, sd::DataType::DOUBLE); + NDArray z('c', {4}, {100,100,100,100}, sd::DataType::FLOAT32); + NDArray exp('c', {4}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + x.permutei({1,0}); + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformFloat(&lc, sd::transform::Sqrt, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execTransformFloat_2) { + + NDArray x('c', {1,4}, {0, 4, 9, 16}, sd::DataType::INT64); + NDArray z('c', {2,2}, {100,100,100,100}, sd::DataType::DOUBLE); + NDArray 
exp('c', {2,2}, {0, 2, 3, 4}, sd::DataType::DOUBLE); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformFloat(&lc, sd::transform::Sqrt, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execTransformAny_1) { + + NDArray x('c', {2,2}, {0, 6.25, 2.25, 12.25}, sd::DataType::DOUBLE); + NDArray z('c', {4,1}, {100,100,100,100}, sd::DataType::INT32); + NDArray exp('c', {4,1}, {0, 2, 6, 12}, sd::DataType::INT32); + x.permutei({1,0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformAny(&lc, sd::transform::Assign, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + 
+//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execTransformAny_2) { + + NDArray x('c', {1,4}, {0, 6.25, 2.25, 12.25}, sd::DataType::BFLOAT16); + NDArray z('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); + NDArray exp('c', {2,2}, {0, 6.25, 2.25, 12.25}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformAny(&lc, sd::transform::Assign, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execTransformStrict_1) { + + NDArray x('c', {2,3}, {0,2,4,1,3,5}, sd::DataType::DOUBLE); + NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::DOUBLE); + NDArray exp('c', {3,2}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); + x.permutei({1,0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformStrict(&lc, sd::transform::CubeDerivative, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); 
ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execTransformStrict_2) { + + NDArray x('c', {6}, {0,1,2,3,4,5}, sd::DataType::FLOAT32); + NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); + NDArray exp('c', {3,2}, {0, 3, 12, 27, 48, 75}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformStrict(&lc, sd::transform::CubeDerivative, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execTransformSame_1) { + + NDArray x('c', {2,3}, {0,2.5,4.5,1.5,3.5,5.5}, sd::DataType::DOUBLE); + NDArray z('c', {1,6}, {100,100,100,100,100,100}, sd::DataType::DOUBLE); + NDArray exp('c', {1,6}, {0,2.25,6.25,12.25,20.25,30.25}, sd::DataType::DOUBLE); + x.permutei({1,0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + 
NativeOpExecutioner::execTransformSame(&lc, sd::transform::Square, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execTransformSame_2) { + + NDArray x('c', {6}, {0,1,2,3,4,5}, sd::DataType::INT32); + NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::INT32); + NDArray exp('c', {3,2}, {0,1,4,9,16,25}, sd::DataType::INT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformSame(&lc, sd::transform::Square, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execTransformBool_1) { + + NDArray x('c', {2,3}, {0,2,4,-1,-3,-5}, sd::DataType::DOUBLE); + NDArray z('c', {1,6}, {100,100,100,100,100,100}, sd::DataType::BOOL); + NDArray exp('c', {1,6}, {0,0,1,0,1,0}, sd::DataType::BOOL); + x.permutei({1,0}); + + // create cuda stream 
and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformBool(&lc, sd::transform::IsPositive, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execTransformBool_2) { + + NDArray x('c', {6}, {0,-1,2,-3,4,-5}, sd::DataType::INT32); + NDArray z('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::BOOL); + NDArray exp('c', {3,2}, {0,0,1,0,1,0}, sd::DataType::BOOL); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + NativeOpExecutioner::execTransformBool(&lc, sd::transform::IsPositive, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceFloat_1) { + + NDArray x('c', {2,3,4}, 
{-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); + NDArray z('c', {3}, {100,100,100}, sd::DataType::FLOAT32); + NDArray exp('c', {3}, {2.5, 6.5, 10.5}, sd::DataType::FLOAT32); + x.permutei({2,1,0}); + + std::vector dimensions = {0,2}; + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + std::vector dims = sd::ShapeUtils::evalDimsForReduceOp(x.rankOf(), dimensions); + NativeOpExecutioner::execReduceFloat(&lc, sd::reduce::Mean, nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dims.data(), dims.size()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceFloat_2) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); + NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); + NDArray exp('c', {2,4}, {-1., 0., 1., 2.,11., 12., 13., 14.}, sd::DataType::DOUBLE); + + std::vector dimensions = {1}; + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + std::vector dims = sd::ShapeUtils::evalDimsForReduceOp(x.rankOf(), dimensions); + NativeOpExecutioner::execReduceFloat(&lc, sd::reduce::Mean, nullptr, x.shapeInfo(), x.specialBuffer(), 
x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dims.data(), dims.size()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceSame_1) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); + NDArray z('c', {3}, {100,100,100}, sd::DataType::INT32); + NDArray exp('c', {3}, {20, 52, 84}, sd::DataType::INT32); + x.permutei({2,1,0}); + + std::vector dimensions = {0,2}; + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + std::vector dims = sd::ShapeUtils::evalDimsForReduceOp(x.rankOf(), dimensions); + NativeOpExecutioner::execReduceSame(&lc, sd::reduce::Sum, nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dims.data(), dims.size()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceSame_2) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::FLOAT32); + NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, 
sd::DataType::FLOAT32); + NDArray exp('c', {2,4}, {-3., 0., 3., 6.,33., 36., 39., 42.}, sd::DataType::FLOAT32); + + std::vector dimensions = {1}; + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + std::vector dims = sd::ShapeUtils::evalDimsForReduceOp(x.rankOf(), dimensions); + NativeOpExecutioner::execReduceSame(&lc, sd::reduce::Sum, nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dims.data(), dims.size()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceBool_1) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::INT32); + NDArray z('c', {3}, {100,100,100}, sd::DataType::BOOL); + NDArray exp('c', {3}, {0, 1, 1}, sd::DataType::BOOL); + x.permutei({2,1,0}); + + std::vector dimensions = {0,2}; + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + std::vector dims = sd::ShapeUtils::evalDimsForReduceOp(x.rankOf(), dimensions); + NativeOpExecutioner::execReduceBool(&lc, sd::reduce::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dims.data(), dims.size()); + + cudaResult = 
cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceBool_2) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::FLOAT32); + NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::BOOL); + NDArray exp('c', {2,4}, {1, 1, 1, 1, 0, 0, 0, 0}, sd::DataType::BOOL); + + std::vector dimensions = {1}; + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + std::vector dims = sd::ShapeUtils::evalDimsForReduceOp(x.rankOf(), dimensions); + NativeOpExecutioner::execReduceBool(&lc, sd::reduce::IsPositive, nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dims.data(), dims.size()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceLong_1) { + + NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::INT32); + NDArray z('c', {3}, {100,100,100}, sd::DataType::INT64); + NDArray exp('c', {3}, {5,6,6}, sd::DataType::INT64); + x.permutei({2,1,0}); + + std::vector dimensions = {0,2}; + + // create 
cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + std::vector dims = sd::ShapeUtils::evalDimsForReduceOp(x.rankOf(), dimensions); + NativeOpExecutioner::execReduceLong(&lc, sd::reduce::CountNonZero, nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dims.data(), dims.size()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceLong_2) { + + NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::FLOAT32); + NDArray z('c', {2,4}, {100,100,100,100,100,100,100,100}, sd::DataType::INT64); + NDArray exp('c', {2,4}, {3, 1, 3, 2, 2, 1, 2, 3}, sd::DataType::INT64); + + std::vector dimensions = {1}; + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // call cuda kernel which calculates result + std::vector dims = sd::ShapeUtils::evalDimsForReduceOp(x.rankOf(), dimensions); + NativeOpExecutioner::execReduceLong(&lc, sd::reduce::CountNonZero, nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), dims.data(), dims.size()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + 
ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); + ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceFloatScalar_1) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + NDArray exp('c', {}, std::vector{6.5}, sd::DataType::FLOAT32); + x.permutei({2,1,0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + int* allocationPointer; + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceFloatScalar(&lc, sd::reduce::Mean, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceFloatScalar_2) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{6.5}, 
sd::DataType::DOUBLE); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + int* allocationPointer; + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceFloatScalar(&lc, sd::reduce::Mean, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceSameScalar_1) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::INT32); + NDArray exp('c', {}, std::vector{156}, sd::DataType::INT32); + x.permutei({2,1,0}); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + int* allocationPointer; + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); 
+ lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceSameScalar(&lc, sd::reduce::Sum, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceSameScalar_2) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}, sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, sd::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{156}, sd::DataType::DOUBLE); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + int* allocationPointer; + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceSameScalar(&lc, sd::reduce::Sum, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + 
cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceBoolScalar_1) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::BOOL); + NDArray exp('c', {}, std::vector{1}, sd::DataType::BOOL); + x.permutei({2,1,0}); + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + int* allocationPointer; + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceBoolScalar(&lc, sd::reduce::IsPositive, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceBoolScalar_2) { + + NDArray x('c', {2,3,4}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6,-7,-8,-9,-10,-11,-12,-13,-14,-15,-16,-17,-18}, sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, sd::DataType::BOOL); + NDArray exp('c', {}, std::vector{1}, sd::DataType::BOOL); + + // create cuda stream and 
LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + int* allocationPointer; + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceBoolScalar(&lc, sd::reduce::IsPositive, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceLongScalar_1) { + + NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::INT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray exp('c', {}, std::vector{17}, sd::DataType::INT64); + x.permutei({2,1,0}); + x.syncShape(); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + int* allocationPointer; + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + 
lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceLongScalar(&lc, sd::reduce::CountNonZero, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduceLongScalar_2) { + + NDArray x('c', {2,3,4}, {-5,0,-3,0,-1,0,1,2,3,4,5,6,7,0,9,10,11,0,13,14,0,16,0,18}, sd::DataType::DOUBLE); + NDArray z('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray exp('c', {}, std::vector{17}, sd::DataType::INT64); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + int* allocationPointer; + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduceLongScalar(&lc, sd::reduce::CountNonZero, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda 
stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3TAD_1) { + + NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::FLOAT32); + NDArray y('c', {2,2}, {1,2,3,4}, sd::DataType::FLOAT32); + NDArray exp('c', {3}, {10,20,30}, sd::DataType::DOUBLE); + NDArray z('c', {3}, {100,100,100}, sd::DataType::DOUBLE); + + std::vector dimensions = {0,1}; + auto packX = ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dimensions); + LaunchContext* context = x.getContext(); + + x.syncToDevice(); + y.syncToDevice(); + PointersManager pm(context, "execReduce3TAD_1"); + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD(context, sd::reduce3::Dot, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + nullptr, dimensions.size(), + packX.specialShapeInfo(), packX.specialOffsets(), nullptr, nullptr); + pm.synchronize(); +// cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); +// z.printIndexedBuffer("OutputReduce3TAD"); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3TAD_2) { + + NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); + NDArray y('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT64); + NDArray exp('c', {2}, {10,73}, sd::DataType::FLOAT32); + NDArray z('c', {2}, {100,100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0,2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + 
xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD(&lc, sd::reduce3::Dot, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3TAD_3) { + + NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); + NDArray y('c', {3}, {1,2,3}, sd::DataType::INT64); + NDArray exp('c', {2,2}, {-22,-4,14,32}, 
sd::DataType::FLOAT32); + NDArray z('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); + + std::vector dimensions = {2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD(&lc, sd::reduce3::Dot, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + 
+//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execReduce3TAD_4) { + + NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::DOUBLE); + NDArray y('c', {2,2,3}, {10,20,30,40,50,60,70,80,90,100,110,120}, sd::DataType::DOUBLE); + NDArray exp('c', {}, std::vector{1820}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0,1,2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execReduce3TAD(&lc, sd::reduce3::Dot, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2]); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + 
z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execSummaryStats_1) { + // FIXME: Yurii, this test should be fixed + if (1 > 0) + return; + + NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); + NDArray exp('c', {}, std::vector{3.605551}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStats(&lc, sd::variance::SummaryStatsStandardDeviation, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + true); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execSummaryStats_2) { + + NDArray x('c', {2,2,3}, {-5,-4,-3,-20,-1,0,1,2,3,4,5,6}, sd::DataType::DOUBLE); + NDArray exp('c', {2}, {3.405877, 9.715966}, sd::DataType::FLOAT32); + NDArray z('c', {2}, 
{100,100}, sd::DataType::FLOAT32); + + std::vector dimensions = {0,2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStats(&lc, sd::variance::SummaryStatsStandardDeviation, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + true); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} +/* +//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execSummaryStats_3) { + + NDArray x('c', {2,2,3}, 
{-5,-4,-3,-20,-1,0,1,2,3,4,5,6}, sd::DataType::DOUBLE); + NDArray exp('c', {2}, {10.606602, 2.121320}, sd::DataType::FLOAT32); + NDArray z('c', {2}, {100,100}, sd::DataType::FLOAT32); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo));// 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStats(&lc, sd::variance::SummaryStatsStandardDeviation, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.special(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + true); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} +*/ + 
+//////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execSummaryStatsScalar_1) { + + NDArray x('c', {2,2,3}, {-5,-4,-3,-2,-1,0,1,2,3,4,5,6}, sd::DataType::INT64); + NDArray exp('c', {}, std::vector{3.605551}, sd::DataType::FLOAT32); + NDArray z('c', {}, std::vector{100}, sd::DataType::FLOAT32); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); ASSERT_EQ(0, cudaResult); + lc.setReductionPointer(reductionPointer); + + // call cuda kernel which calculates result + NativeOpExecutioner::execSummaryStatsScalar(&lc, sd::variance::SummaryStatsStandardDeviation, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + true); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execRandom_1) { + +// NDArray z('c', {10}, {100,0,0,0,0,0,0,0,0,0}, sd::DataType::DOUBLE); + NDArray z('c', {10}, {100,0,0,0,0,0,0,0,0,100}, sd::DataType::FLOAT32); + NDArray exp('c', {10}, {0.050942, -0.183229, -0.093921, 0.075469, 0.257166, -0.254838, 0.342227, -0.682188, -0.004345, 0.464633}, sd::DataType::FLOAT32); + + sd::graph::RandomGenerator gen(119,5); + + cudaError_t cudaResult; + NDArray* array = &z; + ExtraArguments arguments({0.f, 0.5f}); + auto context = z.getContext(); + PointersManager pm(context, "tests::execRandom_1"); +// z.printIndexedBuffer("Input data"); +// 
z.syncToDevice(); + NativeOpExecutioner::execRandom(context, random::GaussianDistribution, &gen, array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), array->buffer(), array->shapeInfo(), array->specialBuffer(), array->specialShapeInfo(), arguments.argumentsAsT(array->dataType())); + pm.synchronize(); + z.tickWriteDevice(); +// z.printIndexedBuffer("Output Gaussian"); +// RandomLauncher::fillGaussian(context, gen, &z, 0.f, 0.5f); +// pm.synchronize(); +// z.tickWriteDevice(); +// z.printIndexedBuffer("Output Gaussian"); + +// cudaStream_t stream; +// cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); +// LaunchContext lc(&stream); +// +// // ::execRandom(extraPointers, random::GaussianDistribution, &gen, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.special(), &extra); +// // call cuda kernel which calculates result +// NativeOpExecutioner::execRandom(&lc, sd::random::GaussianDistribution, +// &gen, +// nullptr, z.shapeInfo(), z.specialBuffer(), z.special(), +// nullptr, z.shapeInfo(), z.specialBuffer(), z.special(), +// nullptr, z.shapeInfo(), z.specialBuffer(), z.special(), +// extraArguments.argumentsAsT(z.dataType())); +// +// cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); +// ASSERT_EQ(cudaResult, 0); +// z.tickWriteDevice(); +// z.syncToHost(); +// z.printIndexedBuffer("Random1"); + ASSERT_EQ(exp, z); +// // verify results +// for (int e = 0; e < z.lengthOf(); e++) +// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); +// cudaFree(dExtraArgs); + // free allocated global device memory +// cudaFree(dGen); + // delete cuda stream +// cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execRandom_2) { + + NDArray x('c', {10}, {0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1}, sd::DataType::DOUBLE); + NDArray 
z('c', {2,5}, {100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); + NDArray exp('c', {10}, {0., 0., 0.3, 0., 0.5, 0., 0.7, 0., 0., 1.}, sd::DataType::DOUBLE); + + ExtraArguments extraArguments({0.7}); + sd::graph::RandomGenerator gen(119,5); + +// // prepare input arrays for prepareDataForCuda function +// std::vector> hostData; +// hostData.emplace_back(extraArguments.data(), extraArguments.size() * sizeof(double)); // 0 -- dimensions +// std::vector devicePtrs(hostData.size(), nullptr); +// + // create cuda stream and LaunchContext + cudaError_t cudaResult; +// cudaStream_t stream; +// cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext* lc = x.getContext(); //(&stream); + + // allocate required amount of global device memory and copy host data to it +// cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execRandom(lc, sd::random::DropOut, + &gen, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + extraArguments.argumentsAsT(z.dataType())); + + cudaResult = cudaStreamSynchronize(*lc->getCudaStream()); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + z.syncToHost(); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory +// for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream +// cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execRandom_3) { + + NDArray z('c', {10}, {100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); + NDArray exp('c', {10}, {2.373649, 2.239791, 1.887353, 2.488636, 2.068904, 2.281399, 1.828228, 2.228222, 2.490847, 1.669537}, sd::DataType::DOUBLE); + + 
std::vector extraArguments = {1.5, 2.5}; + sd::graph::RandomGenerator gen(119,5); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(extraArguments.data(), extraArguments.size() * sizeof(double)); // 0 -- dimensions + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execRandom(&lc, sd::random::UniformDistribution, + &gen, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + devicePtrs[0]); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests1, execRandom_4) { + + NDArray z('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::FLOAT32); + NDArray exp('c', {10}, {2.373649, 2.281399, 2.239791, 1.828228, 1.887353, 2.228222, 2.488636, 2.490847, 2.068904, 1.669537}, sd::DataType::FLOAT32); + z.permutei({1,0}); + + ExtraArguments extraArguments({1.5, 2.5}); + sd::graph::RandomGenerator gen(119,5); + +// // prepare input arrays for prepareDataForCuda function +// std::vector> hostData; +// hostData.emplace_back(extraArguments.data(), extraArguments.size() * sizeof(double)); // 0 -- dimensions +// std::vector 
devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext +// cudaError_t cudaResult; +// cudaStream_t stream; +// cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); +// LaunchContext lc(&stream); +// +// // allocate required amount of global device memory and copy host data to it +// cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + auto context = z.getContext(); + PointersManager pm(context, "execRandom4"); + // call cuda kernel which calculates result + NativeOpExecutioner::execRandom(context, sd::random::UniformDistribution, + &gen, + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + extraArguments.argumentsAsT(z.dataType())); + +// cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); +// z.printIndexedBuffer("Output Uniform4"); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory +// for(int i = 0; i < devicePtrs.size(); ++i) cudaFree(devicePtrs[i]); + + // delete cuda stream +// cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaBasicsTests2.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaBasicsTests2.cu new file mode 100644 index 000000000..bc95ce39b --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaBasicsTests2.cu @@ -0,0 +1,1159 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author raver119@gmail.com + // + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace sd; +using namespace sd::graph; + +class CudaBasicsTests2 : public testing::Test { +public: + +}; + +TEST_F(CudaBasicsTests2, test_devices_1) { + auto caps = Environment::getInstance().capabilities(); + ASSERT_FALSE(caps.empty()); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_1) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); + NDArray c('f', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printIndexedBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_2) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); + NDArray c('f', {M,N}, 
sd::DataType::DOUBLE); + NDArray exp('f', {M,N}, {-1.6, -0.7, 0.2, -0.8, 0.1, 1., -0., 0.9, 1.8, 0.8, 1.7, 2.6, 1.6, 2.5, 3.4}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_3) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); + NDArray c('f', {M,N}, sd::DataType::DOUBLE); + + NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_4) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); + NDArray c('c', {M,N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M,N}, {0.1, 2.5, 4.9, 7.3, 9.7,0.3, 2.7, 5.1, 7.5, 9.9,0.5, 2.9, 5.3, 7.7, 10.1}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + ASSERT_TRUE(c.equalsTo(&exp)); + + + // NDArray* pA = a.permute({1,0}); + // NDArray* pB = b.permute({1,0}); + // NDArray* pC = c.permute({1,0}); + + // sd::MmulHelper::mmul(pB, pA, pC, 1., 0.); + // ASSERT_TRUE(c.equalsTo(&exp)); + + // delete pA; + // delete pB; + // delete pC; +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_5) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + 
+ NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); + NDArray c('f', {M,N}, sd::DataType::DOUBLE); + + NDArray exp('f', {M,N}, {-8.8, -4.3, 0.2, 8.6, 4.1, -0.4, -8.4, -3.9, 0.6, 8.2, 3.7, -0.8, -8.0, -3.5, 1.}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_6) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); + NDArray c('c', {M,N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M,N}, {-1.6, -0.8, -0.0, 0.8, 1.6, -0.7, 0.1, 0.9, 1.7, 2.5, 0.2, 1.0, 1.8, 2.6, 3.4}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_7) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); + NDArray c('c', {M,N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M,N}, {-1.9, 1.3, -0.7, 0.1, 0.5, -0.9, 0.3, 0.3, -0.9, 1.5, 0.1, -0.7, 1.3, -1.9, 2.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_8) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + 
NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); + NDArray c('c', {M,N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_9) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); + NDArray c('c', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_10) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); + NDArray c('f', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printIndexedBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_11) { + + const Nd4jLong M = 3; + const Nd4jLong K = 
4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); + NDArray c('f', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_12) { + + int devCnt = 0; + cudaGetDevice(&devCnt); + if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 4; + const Nd4jLong K = 4; + const Nd4jLong N = 4; + + NDArray a('f', {M,K}, {1.,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7.}, sd::DataType::INT8); + NDArray b('f', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-1,2,-2,3,-4,5,-6.}, sd::DataType::INT8); + NDArray c('f', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M,N}, {-16., -22., -23., -25., 30., -12., -38., -70., 20., 16., 18., 18., 22., -8., -28., -52.}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_13) { + + int devCnt = 0; + cudaGetDevice(&devCnt); + if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT8); + NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::INT8); + NDArray c('f', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M,N}, {-109., -122., -135., 111., 120., 129., -121., -134., -147., 129., 144., 159., -130., -140., -150.}, sd::DataType::FLOAT32); + + 
sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_14) { + + int devCnt = 0; + cudaGetDevice(&devCnt); + if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT8); + NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::INT8); + NDArray c('c', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_15) { + + int devCnt = 0; + cudaGetDevice(&devCnt); + if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); + NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); + NDArray c('f', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_16) { + + int devCnt = 0; + cudaGetDevice(&devCnt); + if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', 
{M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); + NDArray c('f', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_17) { + + int devCnt = 0; + cudaGetDevice(&devCnt); + if(Environment::getInstance().capabilities()[devCnt].first() < 5) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); + NDArray c('c', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_18) { + + int devCnt = 0; + cudaGetDevice(&devCnt); + if(Environment::getInstance().capabilities()[devCnt].first() < 5.3) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); + NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); + NDArray c('f', {M,N}, sd::DataType::HALF); + + NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::HALF); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + 
ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_19) { + + int devCnt = 0; + cudaGetDevice(&devCnt); + if(Environment::getInstance().capabilities()[devCnt].first() < 5.3) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); + NDArray c('f', {M,N}, sd::DataType::HALF); + + NDArray exp('f', {M,N}, {-1.9, -0.9, 0.1, 1.3, 0.3, -0.7, -0.7, 0.3, 1.3, 0.1, -0.9, -1.9, 0.5, 1.5, 2.5}, sd::DataType::HALF); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_20) { + + int devCnt = 0; + cudaGetDevice(&devCnt); + if(Environment::getInstance().capabilities()[devCnt].first() < 5.3) return; + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); + NDArray c('c', {M,N}, sd::DataType::HALF); + + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::HALF); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 1e-1)); +} +/* +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_21) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT8); + NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, 
sd::DataType::FLOAT32); + NDArray c('c', {M,N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_22) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32); + NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); + NDArray c('c', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_23) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::HALF); + NDArray c('c', {M,N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_24) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::HALF); + NDArray b('c', {K,N}, 
{1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); + NDArray c('c', {M,N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_25) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray b('c', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); + NDArray c('c', {M,N}, sd::DataType::HALF); + + NDArray exp('c', {M,N}, {-8.8, 8.6, -8.4, 8.2, -8.0, -4.3, 4.1, -3.9, 3.7, -3.5, 0.2, -0.4, 0.6, -0.8, 1.}, sd::DataType::HALF); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_26) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + // 3x4 * 4x5 = 3x5 + NDArray a('c', {M,K}, {1.,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT64); + NDArray b('c', {K,N}, {-2,-3,0,1,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::FLOAT32); + NDArray c('c', {M,N}, sd::DataType::DOUBLE); + + NDArray exp('c', {M,N}, {-45., 43., -49., 53., -50., -97., 79., -101., 113., -90., -149., 115., -153., 173., -130.}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_27) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('f', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, 
sd::DataType::HALF); + NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); + NDArray c('f', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M,N}, {0.1, 0.3, 0.5, 2.5, 2.7, 2.9, 4.9, 5.1, 5.3, 7.3, 7.5, 7.7, 9.7, 9.9, 10.1}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + // c.printBuffer(); + + ASSERT_TRUE(c.equalsTo(&exp, 0.01)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxM_28) { + + const Nd4jLong M = 3; + const Nd4jLong K = 4; + const Nd4jLong N = 5; + + NDArray a('c', {M,K}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray b('f', {K,N}, {1,-2,3,-4,5,-6,7,-8,9,-10,11,-12,13,-14,15,-16,17,-18,19,-20}, sd::DataType::DOUBLE); + NDArray c('f', {M,N}, sd::DataType::FLOAT32); + + NDArray exp('f', {M,N}, {-1.6, -0.7, 0.2, -0.8, 0.1, 1., -0., 0.9, 1.8, 0.8, 1.7, 2.6, 1.6, 2.5, 3.4}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &b, &c, 1., 0.); + + ASSERT_TRUE(c.equalsTo(&exp)); +} + */ + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_1) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {0.1, 0.3, 0.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_2) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', 
{M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_3) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_4) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('c', {M}, sd::DataType::DOUBLE); + + NDArray exp('c', {M}, {-1.6, -0.7, 0.2}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_5) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('c', {M}, sd::DataType::DOUBLE); + + NDArray exp('c', {M}, {0.1, 0.3, 0.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_6) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray temp('f', {M,N,5}, 
{16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_7) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_8) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {N,M,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(4, {1,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_9) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + 
NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(3, {0,1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_10) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(2, {0,1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_11) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(13, {0,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + 
+////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_12) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(10, {0,2}); + NDArray y('c', {M}, sd::DataType::DOUBLE); + + NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_13) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(2, {0,1}, true); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_14) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(10, {0,2}, true); + NDArray y('c', {M}, 
sd::DataType::DOUBLE); + + NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_15) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(2, {0,1}); + NDArray y = temp(17, {0,2}); + + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_16) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(2, {0,1}); + NDArray y = temp1(17, {0,2}); + + NDArray exp('c', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_17) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, 
{1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(2, {0,1}); + NDArray y = temp(17, {0,2}, true); + // y.printShapeInfo(); + + NDArray exp('f', {1,M,1}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_18) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(2, {0,1},true); + NDArray y = temp1(17, {0,2},true); + + NDArray exp('c', {1,M,1}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +/* +TEST_F(CudaBasicsTests2, mmulMxV_19) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::FLOAT32); + + NDArray exp('f', {M}, {0.1, 0.3, 0.5}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + 
+////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_20) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray x('f', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('f', {M}, sd::DataType::FLOAT32); + NDArray exp('f', {M}, {-1.6, -0.7, 0.2}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_21) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray x('c', {N}, {1,-2,3,-4}, sd::DataType::DOUBLE); + NDArray y('c', {M}, sd::DataType::FLOAT32); + NDArray exp('c', {M}, {-1.6, -0.7, 0.2}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_22) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2}); + NDArray y('f', {M}, sd::DataType::FLOAT32); + + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_23) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, 
sd::DataType::FLOAT32); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(3, {0,1}); + NDArray y('f', {M}, sd::DataType::FLOAT32); + + NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_24) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2},true); + NDArray y('f', {M}, sd::DataType::FLOAT32); + + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_25) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(3, {0,1}, true); + NDArray y('f', {M}, sd::DataType::FLOAT32); + + NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// 
+TEST_F(CudaBasicsTests2, mmulMxV_26) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::FLOAT32); + NDArray x = temp(2, {0,1}); + NDArray y = temp1(17, {0,2}); + + NDArray exp('c', {M}, {-0.3, 0.3, 0.9}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_27) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray temp1('c', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::FLOAT32); + NDArray x = temp(2, {0,1},true); + NDArray y = temp1(17, {0,2},true); + + NDArray exp('c', {1,M,1}, {-0.3, 0.3, 0.9}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulMxV_28) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {M,N}, 
{1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::FLOAT32); + NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2}); + NDArray y('f', {M}, sd::DataType::FLOAT32); + + NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::FLOAT32); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulDot_1) { + + const Nd4jLong N = 4; + + NDArray x('c', {N}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('f', {N}, {0.1, 0.2, 0.3, 0.4}, sd::DataType::FLOAT32); + NDArray z(sd::DataType::DOUBLE); + + NDArray exp('c', {}, {3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&x, &y, &z); + ASSERT_TRUE(z.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulDot_2) { + + const Nd4jLong N = 4; + + NDArray x('c', {1,1,N}, {1,2, 3, 4}, sd::DataType::INT32); + NDArray y('f', {1,1,N,1,1,1}, {0.1, 0.2, 0.3, 0.4}, sd::DataType::FLOAT32); + NDArray z(sd::DataType::DOUBLE); + + NDArray exp('c', {}, {3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&x, &y, &z); + ASSERT_TRUE(z.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulDot_3) { + + const Nd4jLong N = 4; + + NDArray xBig('c', {4,2}, {1, 0, 2, 0, 3, 0, 4, 0}, sd::DataType::INT32); + NDArray yBig('c', {4,3}, {0.1, 0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.4, 0,0}, sd::DataType::FLOAT32); + NDArray x = xBig(0, {1}, true); + NDArray y = yBig(0, {1}, true); + NDArray z(sd::DataType::DOUBLE); + + NDArray exp('c', {}, {3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&x, &y, &z); + ASSERT_TRUE(z.equalsTo(&exp)); +} + 
+////////////////////////////////////////////////////////////////////////// +TEST_F(CudaBasicsTests2, mmulDot_4) { + + const Nd4jLong N = 4; + + NDArray xBig('f', {4,2}, {1, 2, 3, 4, 0, 0, 0, 0}, sd::DataType::INT32); + NDArray yBig('c', {4,3}, {0.1, 0, 0, 0.2, 0, 0, 0.3, 0, 0, 0.4, 0,0}, sd::DataType::FLOAT32); + NDArray x = xBig(0, {1}, true); + NDArray y = yBig(0, {1}); + NDArray z(sd::DataType::DOUBLE); + + NDArray exp('c', {}, {3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&x, &y, &z); + ASSERT_TRUE(z.equalsTo(&exp)); +} + */ \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu new file mode 100644 index 000000000..b1cce9fab --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaExtraArgumentsTests.cu @@ -0,0 +1,76 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; + +class CudaExtraArgumentsTests : public testing::Test { +public: + + CudaExtraArgumentsTests() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(CudaExtraArgumentsTests, Basic_Test_1) { + ExtraArguments args({1.0, 2.0, 3.0}); + + float ef[] = {1.f, 2.f, 3.f}; + double ed[] = {1., 2., 3.}; + + auto ptrFloat = reinterpret_cast(args.argumentsAsT()); + auto ptrDouble = reinterpret_cast(args.argumentsAsT()); + ASSERT_TRUE(ptrFloat != nullptr); + ASSERT_TRUE(ptrDouble != nullptr); + + auto tmpFloat = new float[3]; + auto tmpDouble = new double[3]; + + cudaMemcpy(tmpFloat, ptrFloat, 3 * sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(tmpDouble, ptrDouble, 3 * sizeof(double), cudaMemcpyDeviceToHost); + + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ef[e], tmpFloat[e], 1e-5f); + } + + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ed[e], tmpDouble[e], 1e-5); + } + + delete[] tmpFloat; + delete[] tmpDouble; +} + + +TEST_F(CudaExtraArgumentsTests, Basic_Test_2) { + ExtraArguments args; + + auto ptrInt = args.argumentsAsT(); + ASSERT_TRUE(ptrInt == nullptr); +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp new file mode 100644 index 000000000..97dc00b3a --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/CudaLaunchHelperTests.cpp @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver on 11/26/2018. +// + +#include "testlayers.h" +#include + +using namespace sd; +using namespace sd::graph; + +class CudaLaunchHelperTests : public testing::Test { +public: + +}; + +TEST_F(CudaLaunchHelperTests, test_reduction_blocks_1) { + ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(512)); +} + +TEST_F(CudaLaunchHelperTests, test_reduction_blocks_2) { + ASSERT_EQ(1, CudaLaunchHelper::getReductionBlocks(121)); +} + +TEST_F(CudaLaunchHelperTests, test_reduction_blocks_3) { + ASSERT_EQ(2, CudaLaunchHelper::getReductionBlocks(513)); +} + +TEST_F(CudaLaunchHelperTests, test_reduction_blocks_4) { + ASSERT_EQ(3, CudaLaunchHelper::getReductionBlocks(1225)); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataBufferTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataBufferTests.cpp new file mode 100644 index 000000000..6e74ba7da --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataBufferTests.cpp @@ -0,0 +1,80 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; +using namespace sd::memory; + +class DataBufferTests : public testing::Test { +public: + +}; + +TEST_F(DataBufferTests, test_alloc_limit_1) { + if (!Environment::getInstance().isCPU()) + return; + + auto deviceId = AffinityManager::currentDeviceId(); + auto odLimit = MemoryCounter::getInstance().deviceLimit(deviceId); + auto ogLimit = MemoryCounter::getInstance().groupLimit(MemoryType::HOST); + auto odUse = MemoryCounter::getInstance().allocatedDevice(deviceId); + auto ogUse = MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST); + + auto limitSize = odUse + (150 * 1024 * 1024); + auto allocSize = 100000000; + + MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit + limitSize); + MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, odLimit + limitSize); + + DataBuffer buffer(allocSize, DataType::INT32); + + // separately testing per-device limits and group limits + ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance().allocatedDevice(deviceId)); + ASSERT_EQ(ogUse + allocSize, MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST)); + + + // setting smaller limits, to make sure next allocation fails with OOM exception + 
MemoryCounter::getInstance().setDeviceLimit(deviceId, allocSize - 100); + MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, allocSize - 100); + + try { + DataBuffer bufferFailed(allocSize, DataType::INT32); + ASSERT_TRUE(false); + } catch (allocation_exception &e) { + // we expect exception here + } + + // restore original limits, so subsequent tests do not fail + MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit); + MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, odLimit); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataBufferTestsCuda.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataBufferTestsCuda.cu new file mode 100644 index 000000000..1368c30bd --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataBufferTestsCuda.cu @@ -0,0 +1,91 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; +using namespace sd::memory; + +class DataBufferTestsCuda : public testing::Test { +public: + +}; + +/* +TEST_F(DataBufferTestsCuda, test_alloc_limit_1) { + auto deviceId = AffinityManager::currentDeviceId(); + + auto odLimit = MemoryCounter::getInstance().deviceLimit(deviceId); + + auto opLimit = MemoryCounter::getInstance().groupLimit(MemoryType::HOST); + auto osLimit = MemoryCounter::getInstance().groupLimit(MemoryType::DEVICE); + + auto odUse = MemoryCounter::getInstance().allocatedDevice(deviceId); + + auto opUse = MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST); + auto osUse = MemoryCounter::getInstance().allocatedGroup(MemoryType::DEVICE); + + auto limitSize = odUse + 150000000; + auto allocSize = 100000000; + + MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit + limitSize); + MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, opLimit + limitSize); + MemoryCounter::getInstance().setGroupLimit(MemoryType::DEVICE, osLimit + limitSize); + + DataBuffer buffer(allocSize, DataType::INT32, nullptr, true); + + // separately testing per-device limits and group limits + ASSERT_EQ(odUse + allocSize, MemoryCounter::getInstance().allocatedDevice(deviceId)); + ASSERT_EQ(opUse + allocSize, MemoryCounter::getInstance().allocatedGroup(MemoryType::HOST)); + ASSERT_EQ(osUse + allocSize, MemoryCounter::getInstance().allocatedGroup(MemoryType::DEVICE)); + + // setting smaller limits, to make sure next allocation fails with OOM exception + MemoryCounter::getInstance().setDeviceLimit(deviceId, allocSize - 100); + MemoryCounter::getInstance().setGroupLimit(MemoryType::DEVICE, allocSize - 100); + + + // this allocation 
should fail, since we're allocating too much + try { + DataBuffer bufferFailed(allocSize + 1, DataType::INT32); + ASSERT_TRUE(false); + } catch (allocation_exception &e) { + // we expect exception here + } + + // + + // restore original limits, so subsequent tests do not fail + MemoryCounter::getInstance().setDeviceLimit(deviceId, odLimit); + MemoryCounter::getInstance().setGroupLimit(MemoryType::HOST, opLimit); + MemoryCounter::getInstance().setGroupLimit(MemoryType::DEVICE, osLimit); +} + */ \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataTypesValidationTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataTypesValidationTests.cpp new file mode 100644 index 000000000..c6fe48ca4 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DataTypesValidationTests.cpp @@ -0,0 +1,158 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class DataTypesValidationTests : public testing::Test { +public: + +}; + +TEST_F(DataTypesValidationTests, Basic_Test_1) { + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + + weights.assign(2.0); + input.linspace(1); + + sd::ops::conv2d op; + auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_VALIDATION, result.status()); +} + +TEST_F(DataTypesValidationTests, Basic_Test_2) { + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + + weights.assign(2.0); + input.linspace(1); + + sd::ops::conv2d op; + auto result = op.evaluate({&input, &weights}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(DataTypesValidationTests, Basic_Test_3) { + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); + + weights.assign(2.0); + input.linspace(1); + + sd::ops::conv2d op; + auto result = op.execute({&input, &weights}, {&out}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + 
ASSERT_EQ(Status::OK(), result); + + ASSERT_EQ(exp, out); +} + +TEST_F(DataTypesValidationTests, Basic_Test_4) { + auto input = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto weights = NDArrayFactory::create('c', {1, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {1, 4, 1, 4}, {2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8., 2., 4., 6., 8.}); + auto out = NDArrayFactory::create('c', {1, 4, 1, 4}); + + weights.assign(2.0); + input.linspace(1); + + sd::ops::conv2d op; + auto result = op.execute({&input, &weights}, {&out}, {}, {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}, {}); + ASSERT_EQ(ND4J_STATUS_VALIDATION, result); +} + +TEST_F(DataTypesValidationTests, test_bfloat16_rand_1) { + auto x = NDArrayFactory::create('c', {5, 10}); + RandomGenerator gen(119, 120); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), gen, &x, 1, 6); + + ASSERT_TRUE(x.sumNumber().e(0) != 0.f); +} + +TEST_F(DataTypesValidationTests, test_bfloat16_rand_2) { + auto x = NDArrayFactory::create('c', {5, 10}); + RandomGenerator gen(119, 120); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), gen, &x, 0, 1); + + ASSERT_TRUE(x.sumNumber().e(0) != 0.f); +} + +TEST_F(DataTypesValidationTests, cast_1) { + + float16 x = static_cast(1.f); + float y = static_cast(x); + + ASSERT_TRUE(static_cast(1.f) == x); + ASSERT_TRUE(y == static_cast(x)); +} + +TEST_F(DataTypesValidationTests, test_bits_hamming_distance_1) { + auto x = NDArrayFactory::create('c', {3}, {0b01011000, 0b01011111, 0b01111110}); + auto y = NDArrayFactory::create('c', {3}, {0b00010110, 0b01011000, 0b01011000}); + auto z = NDArrayFactory::create(0); + + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + sd::ops::bits_hamming_distance op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); +} + +TEST_F(DataTypesValidationTests, test_bits_hamming_distance_2) { + auto x = NDArrayFactory::create('c', {3}, {0b01011000, 0b01011111, 0b01111110}); + auto y = 
NDArrayFactory::create('c', {3}, {0b00010110, 0b01011000, 0b01011000}); + auto z = NDArrayFactory::create(0); + + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + sd::ops::bits_hamming_distance op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests1.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests1.cpp new file mode 100644 index 000000000..99ceddd0c --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests1.cpp @@ -0,0 +1,3370 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author raver119@gmail.com + // + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class DeclarableOpsTests1 : public testing::Test { +public: + + const int bS = 2; // batch size + const int iD = 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int oD = 3; // output depth (= N for dense layer) + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + DeclarableOpsTests1() { + sd::memory::MemoryTracker::getInstance().reset(); + } + + ~DeclarableOpsTests1() { + sd::memory::MemoryTracker::getInstance().summarize(); + } +}; + +template +class TypedDeclarableOpsTests1 : public testing::Test { +public: + + const int bS = 2; // batch size + const int iD = 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int oD = 3; // output depth (= N for dense layer) + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // 
padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = (iH - kH - (kH - 1) * (dH - 1) + 2 * pH) / sH + 1; // output height + const int oW = (iW - kW - (kW - 1) * (dW - 1) + 2 * pW) / sW + 1; // output width + + TypedDeclarableOpsTests1() { + printf("\n"); + } +}; + +typedef ::testing::Types TestingTypes; +TYPED_TEST_CASE(TypedDeclarableOpsTests1, TestingTypes); + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, BasicInitialization1) { + auto concat = new sd::ops::concat(); + std::string expName("concat"); + ASSERT_EQ(expName, *(concat->getOpName())); + + auto x0 = NDArrayFactory::create_('c', { 1, 5 }); + auto x1 = NDArrayFactory::create_('c', { 1, 5 }); + auto x2 = NDArrayFactory::create_('c', { 1, 5 }); + auto x3 = NDArrayFactory::create_('c', { 1, 5 }); + auto x4 = NDArrayFactory::create_('c', { 1, 5 }); + + x0->assign(1.0f); + x1->assign(1.0f); + x2->assign(1.0f); + x3->assign(1.0f); + x4->assign(1.0f); + + auto variableSpace = new VariableSpace(); + + variableSpace->putVariable(-1, x0); + variableSpace->putVariable(-2, x1); + variableSpace->putVariable(-3, x2); + variableSpace->putVariable(-4, x3); + variableSpace->putVariable(-5, x4); + + auto nodeVar = new Variable(); + + variableSpace->putVariable(1, nodeVar); + + Context block(1, variableSpace); + block.getIArguments()->push_back(1); + block.fillInputs({ -1, -2, -3, -4, -5 }); + + ASSERT_FALSE(nodeVar->hasNDArray()); + + Nd4jStatus result = concat->execute(&block); + + ASSERT_TRUE(nodeVar->hasNDArray()); + + ASSERT_EQ(25, nodeVar->getNDArray()->lengthOf()); + + ASSERT_NEAR(25.0, nodeVar->getNDArray()->reduceNumber(reduce::Sum).e(0), 1e-5); + + ASSERT_EQ(ND4J_STATUS_OK, result); + + + delete variableSpace; + delete concat; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, BasicInitialization2) { + auto op = 
sd::ops::OpRegistrator::getInstance().getOperation("concat"); + + ASSERT_TRUE(op != nullptr); + std::string expName("concat"); + ASSERT_EQ(expName, *(op->getOpName())); + + ASSERT_EQ(-1, op->getOpDescriptor()->getNumberOfInputs()); + ASSERT_EQ(1, op->getOpDescriptor()->getNumberOfOutputs()); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ApplyGradientDescent_1) { + auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); + auto y = NDArrayFactory::create('c', { 3,4 }, { 0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0,1.1,1.2 }); + auto exp = NDArrayFactory::create('c', { 3,4 }); + exp.linspace(0.9, 0.9); + sd::ops::apply_sgd op; + auto result = op.evaluate({ &x, &y }, { 1. }, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z = result.at(0); + + ASSERT_TRUE(z->equalsTo(exp)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, AssignBroadcastTest_1) { + auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); + auto y = NDArrayFactory::create('c', { 1,4 }, { 0.1,0.2,0.3,0.4 }); + auto exp = NDArrayFactory::create('c', { 3,4 }, { 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4 }); + sd::ops::assign op; + auto result = op.evaluate({ &x, &y }); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z = result.at(0); + + ASSERT_TRUE(z->equalsTo(exp)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, AssignBroadcastTest_2) { + auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); + auto y = NDArrayFactory::create('c', { 1,4 }, { 0.1,0.2,0.3,0.4 }); + auto eps = NDArrayFactory::create('c', { 3,4 }, { 1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4 }); + auto exp1 = NDArrayFactory::create('c', { 3,4 }); // zero + auto exp2 = NDArrayFactory::create('c', { 1,4 }, { 3, 6, 9, 12 }); + sd::ops::assign_bp op; + auto result = op.evaluate({ &x, 
&y, &eps }); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z1 = result.at(0); + auto z2 = result.at(1); + + ASSERT_TRUE(z1->equalsTo(exp1)); + ASSERT_TRUE(z2->equalsTo(exp2)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, AXpY_Test_1) { + auto x = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); + auto y = NDArrayFactory::create('c', { 3,4 }, { 1,2,3,4,5,6,7,8,9,10,11,12 }); + auto exp = NDArrayFactory::create('c', { 3,4 }); + exp.linspace(3, 3); + sd::ops::axpy op; + auto result = op.evaluate({ &x, &y }, { 2. }); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto z = result.at(0); + + ASSERT_TRUE(z->equalsTo(exp)); + +} + +TEST_F(DeclarableOpsTests1, BasicInitialization3) { + auto op1 = sd::ops::OpRegistrator::getInstance().getOperation("concat"); + std::string expName("concat"); + auto hash = sd::ops::HashHelper::getInstance().getLongHash(expName); + + auto op2 = sd::ops::OpRegistrator::getInstance().getOperation(hash); + + ASSERT_TRUE(op1 == op2); +} + + +TEST_F(DeclarableOpsTests1, SynonymInitialization2) { + auto op = sd::ops::OpRegistrator::getInstance().getOperation("Mul"); + auto op2 = sd::ops::OpRegistrator::getInstance().getOperation("multiply"); + + ASSERT_TRUE(op != nullptr); + std::string expName("multiply"); + ASSERT_EQ(expName, *(op->getOpName())); + ASSERT_TRUE(op == op2); +} + + +TEST_F(DeclarableOpsTests1, TestTensorMmul1) { + + NDArray x('c', { 2, 3, 4 }, sd::DataType::FLOAT32); + NDArray y('c', { 2, 3, 4 }, sd::DataType::FLOAT32); + + x.linspace(1); + y.linspace(1); + + NDArray exp('c', { 2, 2 }, { 650.0, 1586.0, 1586.0, 4250.0 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +TEST_F(DeclarableOpsTests1, TestTensorDot2) { + + 
NDArray x('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); + NDArray y('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); + + NDArray exp('c', { 2, 2 }, { 2300.0, 2444.0, 2444.0, 2600.0 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +TEST_F(DeclarableOpsTests1, TestTensorDot3) { + + NDArray x('c', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); + NDArray y('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); + + NDArray exp('f', { 2, 2 }, { 1090.0, 2818.0, 1168.0, 3040.0 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +TEST_F(DeclarableOpsTests1, TestTensorDot4) { + + NDArray x('f', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. }, sd::DataType::FLOAT32); + NDArray y('c', { 2, 3, 4 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. 
}, sd::DataType::FLOAT32); + + NDArray exp('f', { 2, 2 }, { 1090.0, 1168.0, 2818.0, 3040.0 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 2,1,2,2,1,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot5) { + + auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('c', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); + auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 44,110,160, 66,132, 38, 88,154, 68,170,224,102,204, 82,136,238, 92,230,288,138,276,126,184,322, 116,290,352,174,348,170,232,406, 76,190,160,114,228,182,152,266, 100,250,224,150,300,226,200,350, 124,310,288,186,372,270,248,434, 148,370,352,222,444,314,296,518 }); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot6) { + + auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); + auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 22, 66,110,154, 44, 88,132,176, 34,102,170,238, 68,136,204,272, 46,138,230,322, 92,184,276,368, 58,174,290,406,116,232,348,464, 38,114,190,266, 76,152,228,304, 50,150,250,350,100,200,300,400, 62,186,310,434,124,248,372,496, 74,222,370,518,148,296,444,592 }); + + 
sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot7) { + + auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('c', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); + auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 76,166,112,106,196, 62,136,226, 60,174,208, 98,212,230,136,250, 76,214,336,122,260,174,168,306, 124,286,240,178,340,150,232,394, 100,226,176,142,268,106,184,310, 84,234,272,134,284,274,184,334, 100,274,400,158,332,218,216,390, 148,346,304,214,412,194,280,478 }); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot8) { + + auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); + auto expected = NDArrayFactory::create('c', { 2,4,2,4 }, { 30, 90,150,210, 60,120,180,240, 38,114,190,266, 76,152,228,304, 46,138,230,322, 92,184,276,368, 54,162,270,378,108,216,324,432, 42,126,210,294, 84,168,252,336, 50,150,250,350,100,200,300,400, 58,174,290,406,116,232,348,464, 66,198,330,462,132,264,396,528 }); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 1,1,1,2 }); + + 
ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot9) { + + // NDArray z('f',{2,2,3}, sd::DataType::DOUBLE); + // z.linspace(1); + // z.printShapeInfo(); + // z.printIndexedBuffer(); + // z.reshapei('c', {4,3}); + // z.printShapeInfo(); + // z.printIndexedBuffer(); + + auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); + auto expected = NDArrayFactory::create('c', { 3,4,4,3 }, { 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422, 62, 62, 62,142,142,142,222,222,222,302,302,302, 62, 62, 62,142,142,142,222,222,222,302,302,302, 38, 38, 38, 86, 86, 86,134,134,134,182,182,182, 14, 14, 14, 30, 30, 30, 46, 46, 46, 62, 62, 62, 86, 86, 86,198,198,198,310,310,310,422,422,422 }); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 1,0,1,0 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot10) { + + auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 
}); + auto expected = NDArrayFactory::create('c', { 4,4 }, { 114,258,402,546, 138,314,490,666, 162,370,578,786, 186,426,666,906 }); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 2,0,1, 2,0,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot11) { + + auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('f', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); + auto expected = NDArrayFactory::create('c', { 4,4 }, { 98,218,338,458, 134,302,470,638, 170,386,602,818, 206,470,734,998 }); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 2,0,1, 2,0,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot12) { + + auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('c', { 2,4,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); + auto expected = NDArrayFactory::create('c', { 4,4 }, { 272,292,312,332, 368,396,424,452, 464,500,536,572, 560,604,648,692 }); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 2,0,1, 2,0,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + +//////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests1, TestTensorDot13) { + + auto x = NDArrayFactory::create('c', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('c', { 4,2,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); + auto expected = NDArrayFactory::create('c', { 3,3 }, { 640,560,640, 576,624,576, 640,560,640 }); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 2,0,2, 2,1,0 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot14) { + + auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('c', { 4,2,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); + auto expected = NDArrayFactory::create('c', { 3,3 }, { 648,600,520, 648,536,648, 520,600,648 }); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 2,0,2, 2,1,0 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot15) { + + auto x = NDArrayFactory::create('f', { 2,3,4 }, { 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15, 1,3,5,7,9,11,13,15 }); + auto y = NDArrayFactory::create('f', { 4,2,3 }, { 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16, 2,4,6,8,10,12,14,16 }); + auto expected = NDArrayFactory::create('c', { 3,3 }, { 624,624,624, 656,656,656, 624,624,624 }); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 2,0,2, 2,1,0 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = 
results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot16) { + + NDArray x('c', { 1 }, std::vector{2}, sd::DataType::FLOAT32); + NDArray y('c', { 2,1,2 }, { 1,2,3,4 }, sd::DataType::FLOAT32); + NDArray exp('c', { 2,2 }, { 2,4,6,8 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul op; + auto results = op.evaluate({ &x, &y }, {}, { 1,0, 1,1 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* result = results.at(0); + + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestTensorDot17) { + + NDArray x('f', { 16,16 }, sd::DataType::FLOAT32); + NDArray y('f', { 1000,16 }, sd::DataType::FLOAT32); + NDArray z('c', { 16,1000 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul op; + auto status = op.execute({ &x, &y }, { &z }, {}, { 1,1, 1,1 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, DivergentCheck1) { + auto op = sd::ops::OpRegistrator::getInstance().getOperation("switch"); + + ASSERT_TRUE(op != nullptr); + std::string expName("Switch"); + ASSERT_EQ(expName, *(op->getOpName())); + ASSERT_TRUE(op->getOpDescriptor()->isDivergent()); + ASSERT_EQ(2, op->getOpDescriptor()->getNumberOfOutputs()); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, AddMatrices1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 5, 3 }); + auto exp = NDArrayFactory::create_('c', { 5, 3 }); + x->assign(2); + y->assign(1); + exp->assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new 
Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::add addOp; + + addOp.execute(block); + + ASSERT_TRUE(x->equalsTo(exp)); + + delete exp; + delete block; + delete variableSpace; + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, AddVectorVector1) { + + auto x = NDArrayFactory::create_('c', { 1, 15 }); + auto y = NDArrayFactory::create_('c', { 1, 15 }); + auto exp = NDArrayFactory::create_('c', { 1, 15 }); + x->assign(2); + y->assign(1); + exp->assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::add addOp; + + addOp.execute(block); + + ASSERT_TRUE(x->equalsTo(exp)); + + delete exp; + delete block; + delete variableSpace; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, AddMatrixScalar1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x->assign(2); + y->assign(1); + exp.assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::add addOp; + + addOp.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, AddScalarScalar1) { + + auto x = NDArrayFactory::create_('c', { 1, 1 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 1, 1 }); + x->assign(2); + y->assign(1); + exp.assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + 
variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::add addOp; + + addOp.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, SubtractMatrices1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 5, 3 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x->assign(3); + y->assign(1); + exp.assign(2); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::subtract subOp; + + subOp.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, SubtractTest_1) { + + auto x = NDArrayFactory::create_('c', { 1, 6 }); + auto y = NDArrayFactory::create_('c', { 1, 6 }); + auto exp = NDArrayFactory::create('c', { 1, 6 }); + x->assign(3); + y->assign(1); + exp.assign(2); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::subtract subOp; + + subOp.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + + delete variableSpace; + delete block; +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, SubtractTest_2) { + + auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); + auto y = NDArrayFactory::create('c', { 1, 6 }); + // auto y({6}, {1,1,1,1,1,1}); + auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); + x.assign(3); + y.assign(1); + exp.assign(2); + + + sd::ops::subtract subOp; + + auto 
res = subOp.evaluate({ &x, &y }); + + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + + ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + + +} + +TEST_F(DeclarableOpsTests1, TestRng1) { + /* + Nd4jLong *buffer = new Nd4jLong[100000]; + + sd::random::RandomBuffer *rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, 100000, (Nd4jPointer) buffer); + + if (rng == nullptr) + throw std::runtime_error("RNG initialization failed"); + + auto x = NDArrayFactory::create_('c', {5, 3}); + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + auto block = new Context(1, variableSpace, true); + block->fillInputs({-1}); + block->setRNG(rng); + block->getTArguments()->push_back(0.0f); + block->getTArguments()->push_back(1.0f); + + sd::ops::randomuniform uniform; + + Nd4jStatus status = uniform.execute(block); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(x->sumNumber() > 0.0); + + destroyRandom((Nd4jPointer) rng); + delete[] buffer; + + delete variableSpace; + delete block; + */ +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, MergeSumTest1) { + + auto x = NDArrayFactory::create_('c', { 5, 5 }); + auto y = NDArrayFactory::create_('c', { 5, 5 }); + auto z = NDArrayFactory::create_('c', { 5, 5 }); + auto exp = NDArrayFactory::create('c', { 5, 5 }); + x->assign(3); + y->assign(1); + z->assign(2); + exp.assign(6); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(-3, z); + variableSpace->putVariable(1, new Variable(NDArrayFactory::create_('c', { 5, 5 }))); + auto block = new Context(1, variableSpace, false); + block->fillInputs({ -1, -2, -3 }); + + sd::ops::mergeadd merge; + + merge.execute(block); + + auto res = variableSpace->getVariable(1)->getNDArray(); + + ASSERT_TRUE(res->equalsTo(&exp)); + + delete variableSpace; + delete block; +} + + 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ClipByValue1) { + + auto x = NDArrayFactory::create_('c', { 5, 5 }); + auto exp = NDArrayFactory::create('c', { 5, 5 }); + x->assign(4); + x->p(0, -1); + x->p(1, 2); + exp.assign(3); + exp.p(0, 0); + exp.p(1, 2); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(1, new Variable()); + auto block = new Context(1, variableSpace, true); + block->getTArguments()->push_back(0.0f); + block->getTArguments()->push_back(3.0f); + block->fillInputs({ -1 }); + + sd::ops::clipbyvalue clip; + + clip.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, MergeAvgTest1) { + + auto x = NDArrayFactory::create_('c', { 5, 5 }); + auto y = NDArrayFactory::create_('c', { 5, 5 }); + auto z = NDArrayFactory::create_('c', { 5, 5 }); + auto exp = NDArrayFactory::create('c', { 5, 5 }); + x->assign(3); + y->assign(1); + z->assign(2); + exp.assign(2); + + auto zu = NDArrayFactory::create('c', { 5, 5 }); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(-3, z); + variableSpace->putVariable(1, new Variable(NDArrayFactory::create_('c', { 5, 5 }))); + auto block = new Context(1, variableSpace, false); + block->fillInputs({ -1, -2, -3 }); + + sd::ops::mergeavg merge; + + merge.execute(block); + + auto res = variableSpace->getVariable(1)->getNDArray(); + + ASSERT_TRUE(res->equalsTo(&exp)); + + delete block; + delete variableSpace; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, SubtractVectorVector1) { + + auto x = NDArrayFactory::create_('c', { 1, 15 }); + auto y = NDArrayFactory::create_('c', { 1, 15 }); + auto exp = 
NDArrayFactory::create('c', { 1, 15 }); + x->assign(3); + y->assign(1); + exp.assign(2); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::subtract subOp; + + subOp.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete block; + delete variableSpace; + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, SubtractMatrixScalar1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x->assign(3); + y->assign(1); + exp.assign(2); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::subtract subOp; + + subOp.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete block; + delete variableSpace; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, SubtractScalarScalar1) { + + auto x = NDArrayFactory::create_('c', { 1, 1 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 1, 1 }); + x->assign(3); + y->assign(1); + exp.assign(2); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::subtract subOp; + + subOp.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete block; + delete variableSpace; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseSubtractMatrices1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = 
NDArrayFactory::create_('c', { 5, 3 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x->assign(3.f); + y->assign(1.f); + exp.assign(-2.f); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reversesubtract subOp; + + subOp.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseSubtractTest_1) { + + auto x = NDArrayFactory::create('c', { 1, 6 }); + auto y = NDArrayFactory::create('c', { 1, 6 }); + auto exp = NDArrayFactory::create('c', { 1, 6 }); + x.assign(3.f); + y.assign(1.f); + exp.assign(-2.f); + + sd::ops::reversesubtract subOp; + + auto res = subOp.evaluate({ &x, &y }); + + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseSubtractTest_2) { + + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', { 1, 6 }); + auto y = NDArrayFactory::create('c', { 3, 4, 5, 1 }); + auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); + auto z(exp); + x.assign(3.f); + y.assign(1.f); + exp.assign(-2.f); + x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); + + ASSERT_TRUE(exp.equalsTo(&z)); + + sd::ops::reversesubtract subOp; + + auto res = subOp.evaluate({ &x, &y }); + + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseSubtractTest_3) { + + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', { 6 }); + auto y = NDArrayFactory::create('c', { 3, 4, 5, 1 }); + auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); + auto 
z(exp); + x.assign(1); + y.assign(3); + exp.assign(2); + x.applyTrueBroadcast(BROADCAST(ReverseSubtract), y, z, true); + ASSERT_TRUE(z.equalsTo(&exp)); + sd::ops::reversesubtract subOp; + + auto res = subOp.evaluate({ &x, &y }); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseModTest_1) { + + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', { 6 }); + auto y = NDArrayFactory::create('c', { 3, 4, 5, 1 }); + auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); + auto z(exp); + x.assign(2.); + y.assign(9.f); + exp.assign(1.f); + y.applyTrueBroadcast(BROADCAST(Mod), x, z, true); + ASSERT_TRUE(exp.equalsTo(&z)); + + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); + ASSERT_TRUE(exp.equalsTo(&z)); + + sd::ops::reversemod subOp; + + auto res = subOp.evaluate({ &x, &y }); + + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + ASSERT_TRUE(exp.equalsTo(&z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseModTest_2) { + + // auto x('c', {1, 6}); + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto y = NDArrayFactory::create('c', { 3, 4, 5 }); + auto exp = NDArrayFactory::create('c', { 3, 4, 5 }); + auto z(exp); + x.assign(2.f); + y.assign(9.f); + exp.assign(1.f); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, z, true); + ASSERT_TRUE(z.equalsTo(&exp)); + x.applyTrueBroadcast(BROADCAST(ReverseMod), y, exp, true); + ASSERT_TRUE(z.equalsTo(&exp)); + + sd::ops::reversemod subOp; + + auto res = subOp.evaluate({ &x, &y }); + + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseSubtractVectorVector1) { + + auto x = 
NDArrayFactory::create_('c', { 1, 15 }); + auto y = NDArrayFactory::create_('c', { 1, 15 }); + auto exp = NDArrayFactory::create_('c', { 1, 15 }); + x->assign(3); + y->assign(1); + exp->assign(-2); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reversesubtract subOp; + + subOp.execute(block); + + ASSERT_TRUE(x->equalsTo(exp)); + + delete variableSpace; + delete block; + delete exp; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseSubtractMatrixScalar1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create_('c', { 5, 3 }); + x->assign(3); + y->assign(1); + exp->assign(-2); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reversesubtract subOp; + + subOp.execute(block); + + ASSERT_TRUE(x->equalsTo(exp)); + + delete variableSpace; + delete block; + delete exp; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseSubtractScalarScalar1) { + + auto x = NDArrayFactory::create_('c', { 1, 1 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create_('c', { 1, 1 }); + x->assign(3); + y->assign(1); + exp->assign(-2); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reversesubtract subOp; + + subOp.execute(block); + + ASSERT_TRUE(x->equalsTo(exp)); + + delete variableSpace; + delete block; + delete exp; +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, MultiplyMatrices1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 5, 3 }); + auto exp = NDArrayFactory::create_('c', { 5, 3 }); + x->assign(2); + y->assign(3); + exp->assign(6); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::multiply mul; + + mul.execute(block); + + ASSERT_TRUE(x->equalsTo(exp)); + + delete variableSpace; + delete block; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, MultiplyVectorVector1) { + + auto x = NDArrayFactory::create_('c', { 1, 15 }); + auto y = NDArrayFactory::create_('c', { 1, 15 }); + auto exp = NDArrayFactory::create_('c', { 1, 15 }); + x->assign(2); + y->assign(3); + exp->assign(6); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::multiply mul; + + mul.execute(block); + + ASSERT_TRUE(x->equalsTo(exp)); + + delete variableSpace; + delete block; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, MultiplyMatrixScalar) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create_('c', { 5, 3 }); + x->assign(2); + y->assign(3); + exp->assign(6); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::multiply mul; + + mul.execute(block); + + ASSERT_TRUE(x->equalsTo(exp)); + + 
delete variableSpace; + delete block; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, MultiplyScalarScalar1) { + + auto x = NDArrayFactory::create_('c', { 1, 1 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create_('c', { 1, 1 }); + x->assign(2); + y->assign(3); + exp->assign(6); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::multiply mul; + + mul.execute(block); + + ASSERT_TRUE(x->equalsTo(exp)); + + delete block; + delete variableSpace; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestSoftMax_bp_1) { + + auto input = NDArrayFactory::create_('c', { 2, 2 }); + for (int e = 0; e < input->lengthOf(); e++) + input->p(e, e + 1); + + auto epsilon = NDArrayFactory::create_('c', { 2, 2 }); + epsilon->p(0, 0.1f); + epsilon->p(1, 0.2f); + epsilon->p(2, 0.3f); + epsilon->p(3, 0.4f); + + auto output = NDArrayFactory::create_('c', { 2, 2 }); + output->assign(1.0f); + + auto exp = NDArrayFactory::create_('c', { 2, 2 }); + exp->p(0, -0.019661194f); + exp->p(1, 0.019661194f); + exp->p(2, -0.019661194f); + exp->p(3, 0.019661194f); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, epsilon); + variableSpace->putVariable(1, output); + //variableSpace->putVariable(42, exp); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({ -1, -2 }); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(output->equalsTo(exp)); + + delete variableSpace; + delete block; + delete exp; + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, 
BroadcastDivideTest_1) { + + auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); + auto y = NDArrayFactory::create('c', { 1, 6 }); + auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); + x.assign(6); + y.assign(2); + exp.assign(3); + + sd::ops::divide div; + + auto res = div.evaluate({ &x, &y }); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0)->equalsTo(exp)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, BroadcastDivideTest_2) { + + auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); + auto y = NDArrayFactory::create('c', { 1, 6 }); + auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); + x.assign(6); + y.assign(2); + exp.assign(3); + + sd::ops::divide_no_nan div; + auto res = div.evaluate({ &x, &y }); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0)->equalsTo(exp)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, BroadcastDivideTest_3) { + + auto x = NDArrayFactory::create({ 6,6,6,6,6 }); + auto y = NDArrayFactory::create({ 3,3,0,3,3 }); + auto exp = NDArrayFactory::create({ 2, 2, 0, 2, 2 }); + + sd::ops::divide_no_nan div; + auto res = div.evaluate({ &x, &y }); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(res.at(0)->equalsTo(exp)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, BroadcastReverseDivideTest_1) { + + auto x = NDArrayFactory::create('c', { 3, 4, 5, 1 }); + auto y = NDArrayFactory::create('c', { 1, 6 }); + auto exp = NDArrayFactory::create('c', { 3, 4, 5, 6 }); + x.assign(3.f); + y.assign(6.f); + exp.assign(2.f); + + sd::ops::reversedivide div; + + auto res = div.evaluate({ &x, &y }); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + + ASSERT_TRUE(res.at(0)->equalsTo(exp)); + auto z(exp); + x.applyTrueBroadcast(BROADCAST(ReverseDivide), y, z, true); + y.applyTrueBroadcast(BROADCAST(Divide), x, exp, 
true); + + ASSERT_TRUE(z.equalsTo(&exp)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, DivideMatrices1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 5, 3 }); + auto exp = NDArrayFactory::create_('c', { 5, 3 }); + x->assign(6); + y->assign(2); + exp->assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::divide div; + + div.execute(block); + + ASSERT_TRUE(x->equalsTo(exp)); + + delete variableSpace; + delete block; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, DivideVectorVector1) { + + auto x = NDArrayFactory::create_('c', { 1, 15 }); + auto y = NDArrayFactory::create_('c', { 1, 15 }); + auto exp = NDArrayFactory::create('c', { 1, 15 }); + x->assign(6); + y->assign(2); + exp.assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::divide div; + + div.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, DivideMatrixScalar1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x->assign(6); + y->assign(2); + exp.assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::divide div; + + div.execute(block); + + 
ASSERT_TRUE(x->equalsTo(&exp)); + + delete block; + delete variableSpace; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, DivideScalarScalar1) { + + auto x = NDArrayFactory::create_('c', { 5, 1 }); + auto y = NDArrayFactory::create_('c', { 5, 1 }); + auto exp = NDArrayFactory::create('c', { 5, 1 }); + x->assign(6); + y->assign(2); + exp.assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::divide div; + + div.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseDivideMatrices1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 5, 3 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x->assign(2); + y->assign(6); + exp.assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reversedivide div; + + div.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseDivideVectorVector1) { + + auto x = NDArrayFactory::create_('c', { 1, 15 }); + auto y = NDArrayFactory::create_('c', { 1, 15 }); + auto exp = NDArrayFactory::create('c', { 1, 15 }); + x->assign(2); + y->assign(6); + exp.assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reversedivide div; 
+ + div.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseDivideMatrixScalar1) { + + auto x = NDArrayFactory::create_('c', { 5, 3 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 5, 3 }); + x->assign(2); + y->assign(6); + exp.assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reversedivide div; + + div.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, ReverseDivideScalarScalar1) { + + auto x = NDArrayFactory::create_('c', { 1, 1 }); + auto y = NDArrayFactory::create_('c', { 1, 1 }); + auto exp = NDArrayFactory::create('c', { 1, 1 }); + x->assign(2); + y->assign(6); + exp.assign(3); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reversedivide div; + + div.execute(block); + + ASSERT_TRUE(x->equalsTo(&exp)); + + delete variableSpace; + delete block; +} + +TEST_F(DeclarableOpsTests1, Test_Cast_1) { + // TODO: right now there's no real cast implementation, but genera idea should be the same: arrays equality to be expected + auto x = NDArrayFactory::create('c', { 5, 5 }); + auto yExp = NDArrayFactory::create('c', { 5, 5 }); + x.linspace(1); + yExp.linspace(1); + sd::ops::cast op; + + auto result = op.evaluate({ &x }, {}, { 3 }); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(yExp.equalsTo(z)); + + +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestRegistrator1) { + auto res = sd::ops::OpRegistrator::getInstance().getAllCustomOperations(); +} + +// ////////////////////////////////////////////////////////////////////// +// TEST_F(DeclarableOpsTests1, TestLegacyExecution1) { +// auto x = NDArrayFactory::create_('c', {10, 10}); +// x->assign(1.0f); + +// auto y = NDArrayFactory::create_('c', {10, 10}); +// y->assign(2.0f); + +// auto z = NDArrayFactory::create_('c', {10, 10}); + +// auto exp = NDArrayFactory::create_('c', {10, 10}); +// exp->assign(3.0f); +// z->assign(120.0f); +// std::string opName("add"); + +// auto hash = sd::ops::HashHelper::getInstance().getInstance()->getLongHash(opName); + +// auto inputBuffers = new Nd4jPointer[2]; +// auto inputShapes = new Nd4jPointer[2]; + +// inputBuffers[0] = (Nd4jPointer) x->buffer(); +// inputBuffers[1] = (Nd4jPointer) y->buffer(); + +// inputShapes[0] = (Nd4jPointer) x->shapeInfo(); +// inputShapes[1] = (Nd4jPointer) y->shapeInfo(); + +// auto outputBuffers = new Nd4jPointer[1]; +// auto outputShapes = new Nd4jPointer[1]; + +// outputBuffers[0] = (Nd4jPointer) z->buffer(); +// outputShapes[0] = (Nd4jPointer) z->shapeInfo(); + + +// //auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, false); +// auto status = execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); +// ASSERT_EQ(ND4J_STATUS_OK, status); +// ASSERT_NEAR(2.0f, y->meanNumber().e(0), 1e-5); +// ASSERT_NEAR(1.0f, x->meanNumber().e(0), 1e-5); +// ASSERT_NEAR(3.0f, z->meanNumber().e(0), 1e-5); + +// delete x; +// delete y; +// delete z; +// delete exp; +// delete[] inputBuffers; +// delete[] inputShapes; +// delete[] outputBuffers; +// delete[] outputShapes; +// } + +// ////////////////////////////////////////////////////////////////////// +// 
TEST_F(DeclarableOpsTests1, TestLegacyExecution2) { +// auto x = NDArrayFactory::create_('c', {10, 10}); +// x->assign(1.0f); + +// auto y = NDArrayFactory::create_('c', {10, 10}); +// y->assign(2.0f); + +// auto z = NDArrayFactory::create_('c', {10, 10}); + +// auto exp = NDArrayFactory::create_('c', {10, 10}); +// exp->assign(3.0); + +// std::string opName("add"); + +// auto hash = sd::ops::HashHelper::getInstance().getInstance()->getLongHash(opName); + +// auto inputBuffers = new Nd4jPointer[2]; +// auto inputShapes = new Nd4jPointer[2]; + +// inputBuffers[0] = (Nd4jPointer) x->buffer(); +// inputBuffers[1] = (Nd4jPointer) y->buffer(); + +// inputShapes[0] = (Nd4jPointer) x->shapeInfo(); +// inputShapes[1] = (Nd4jPointer) y->shapeInfo(); + +// auto outputBuffers = new Nd4jPointer[1]; +// auto outputShapes = new Nd4jPointer[1]; + +// execCustomOp(nullptr, hash, inputBuffers, inputShapes, 2, outputBuffers, outputShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, true); + +// ASSERT_NEAR(2.0, y->meanNumber().e(0), 1e-5); +// ASSERT_NEAR(3.0, x->meanNumber().e(0), 1e-5); + + +// delete x; +// delete y; +// delete z; +// delete exp; +// delete[] inputBuffers; +// delete[] inputShapes; +// delete[] outputBuffers; +// delete[] outputShapes; +// } + +#ifndef __CUDABLAS__ +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestGemv1) { + /* + auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; + auto xShape = new Nd4jLong[8] {2, 5, 3, 3, 1, 0, 1, 99}; + ArrayOptions::setDataType(xShape, sd::DataType::FLOAT32); + auto x = new NDArray(xBuffer, xShape); + + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + auto yShape = new Nd4jLong[8] {2, 3, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(yShape, sd::DataType::FLOAT32); + + auto y = new NDArray(yBuffer, yShape); + + auto z = NDArrayFactory::create_('f', {5, 1}); + + auto expBuffer = new 
float[5]{28.00f,64.00f,100.00f,136.00f,172.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); + + sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); + + ASSERT_TRUE(z->equalsTo(exp)); + + delete []xBuffer; delete []xShape; delete x; delete []yBuffer; delete []yShape; delete y; delete z; delete []expBuffer; delete exp; + */ +} + +#endif + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Transpose1) { + auto x = NDArrayFactory::create_('c', { 3,5,2 }); + auto exp = NDArrayFactory::create_('c', { 2,5,3 }); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(1, new Variable()); + + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({ -1 }); + sd::ops::transpose transpose; + + Nd4jStatus status = transpose.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + + ASSERT_TRUE(exp->isSameShape(result)); + ASSERT_TRUE(exp->dataType() == result->dataType()); + ASSERT_TRUE(exp->ordering() == result->ordering()); + + delete exp; + delete block; + delete variableSpace; +} + +////////////////////////////////////////////////////////////////////// +// not-in-place +TEST_F(DeclarableOpsTests1, Permute1) { + + Nd4jLong shapeX[] = { 3, 5,10,15, 150,15,1, 0,1,99 }; + Nd4jLong shapeExp[] = { 3, 15,5,10, 50,10,1, 0,1,99 }; + const std::vector perm = { 2, 0, 1 }; + + ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); + + auto x = new NDArray(shapeX, true); + auto exp = new NDArray(shapeExp, true); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(1, new Variable()); + + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({ -1 
}); + auto arguments = block->getIArguments(); + *arguments = perm; // set dimensions to be permuted + + sd::ops::permute permute; + Nd4jStatus status = permute.execute(block); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(result->isSameShapeStrict(*exp)); + + delete block; + delete variableSpace; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestArgumentsValidation1) { + Nd4jLong shapeX[] = { 3, 5, 10, 15, 150, 15, 1, 0, 1, 99 }; + Nd4jLong shapeExp[] = { 3, 15, 5, 10, 1, 150, 15, 0, -1, 99 }; + + ArrayOptions::setDataType(shapeX, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shapeExp, sd::DataType::FLOAT32); + + const std::vector perm = { 2, 0, 1 }; + auto x = new NDArray(shapeX); + auto exp = new NDArray(shapeExp); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(1, new Variable()); + + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({ -1 }); + + sd::ops::im2col permute; + Nd4jStatus status = permute.execute(block); + + ASSERT_TRUE(status != 0); + + delete exp; + delete block; + delete variableSpace; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestReductionShape1) { + auto input = NDArrayFactory::create_('c', { 4, 5, 5, 10, 10 }); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({ -1 }); + + // kernel params + block->getIArguments()->push_back(MAX_INT); + + sd::ops::testreduction testop; + + auto inP = new Nd4jLong[shape::shapeInfoLength(input->shapeInfo())]; + memcpy(inP, input->shapeInfo(), shape::shapeInfoByteLength(input->rankOf())); + + auto inshape = new ShapeList(inP); + + auto shapes = 
testop.calculateOutputShape(inshape, *block); + + ASSERT_EQ(1, shapes->size()); + ASSERT_EQ(0, shapes->at(0)[0]); // scalar shape has rank 0 + ASSERT_EQ(8192, shapes->at(0)[1]); + ASSERT_EQ(1, shapes->at(0)[2]); + + delete[] inP; + delete variableSpace; + delete block; + delete inshape; + delete shapes; + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestReductionShape2) { + auto input = NDArrayFactory::create_('c', { 4, 5, 5, 10, 10 }); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({ -1 }); + + // kernel params + //block->getIArguments()->push_back(4); + block->getIArguments()->push_back(1); + block->getIArguments()->push_back(2); + block->getIArguments()->push_back(3); + block->getIArguments()->push_back(4); + + sd::ops::testreduction testop; + + auto inshapes = new ShapeList(input->shapeInfo()); + auto shapes = testop.calculateOutputShape(inshapes, *block); + ASSERT_EQ(1, shapes->size()); + ASSERT_EQ(1, shapes->at(0)[0]); + ASSERT_EQ(4, shapes->at(0)[1]); + ASSERT_EQ(1, shapes->at(0)[2]); + + delete variableSpace; + delete block; + delete shapes; + delete inshapes; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, TestCustomShape1) { + auto input = NDArrayFactory::create_('c', { 2, 3, 4 }); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, input); + + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({ -1 }); + + sd::ops::testcustom test; + + auto inshapes = new ShapeList(input->shapeInfo()); + auto shapes = test.calculateOutputShape(inshapes, *block); + + + ASSERT_EQ(input->shapeInfo()[0], shapes->at(0)[0]); + ASSERT_EQ(input->shapeInfo()[1] * 2, shapes->at(0)[1]); + ASSERT_EQ(input->shapeInfo()[2] * 2, shapes->at(0)[2]); + 
ASSERT_EQ(input->shapeInfo()[3] * 2, shapes->at(0)[3]); + + delete variableSpace; + delete block; + delete shapes; + delete inshapes; +} + + +////////////////////////////////////////////////////////////////////// +/* +TEST_F(DeclarableOpsTests1, Sum1) { + + float xBuff[] = {1, 2, 3, 4, 5, 6, 7, 8}; + int xShape[] = {2, 4, 2, 2, 1, 0, 1, 99}; + float expBuff[] = {16, 20}; + int expShape[] = {2, 1, 2, 2, 1, 0, 1, 99}; + + const std::vector dimensions = {1,0}; + + auto x = NDArrayFactory::create_(xBuff, xShape); + auto z = NDArrayFactory::create_(1, 2, 'c'); + auto exp(expBuff, expShape); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(1, z); + + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); + std::vector* arguments = block->getIArguments(); + *arguments = dimensions; + + sd::ops::sum sum; + Nd4jStatus status = sum.execute(block); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(result->equalsTo(&exp)); + + delete block; + delete variableSpace; +} +*/ + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Pnormpool2d1) { + + auto x = NDArrayFactory::create_('c', { bS,iD,iH,iW }); + auto exp = NDArrayFactory::create('c', { bS,iD,oH,oW }); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({ -1 }); + std::vector* argI = block->getIArguments(); + *argI = { kH,kW, sH,sW, pH,pW, dW,dH, 0, 1, 0 }; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; 9 - extraParam0 for pnorm case; + + sd::ops::pnormpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, 
status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +/*///////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, IsMax1) { + + float xBuff[] = {1,2,3,4,5,6,7,8,9}; + Nd4jLong xShape[] = {2,3,3,3,1,0,1,99}; + bool expBuff[] = {0,0,1,0,0,1,0,0,1}; + ArrayOptions::setDataType(xShape, sd::DataType::BOOL); + + auto x = new NDArray(xBuff, xShape); + NDArray exp(expBuff, xShape); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); +// *argI = {1}; // dimensions + argI->push_back(1); // = {1}; // dimensions + + sd::ops::ismax ismaxOp; + Nd4jStatus status = ismaxOp.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + result->printIndexedBuffer("IS_MAX"); + ASSERT_TRUE(exp.equalsTo(result)); + + delete variableSpace; + delete block; +} +*/ + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, IsMax1) { + NDArray x('c', { 3, 3 }, sd::DataType::FLOAT32); + // NDArray exp('c', {3, 3}, sd::DataType::BOOL); + NDArray exp('c', { 3, 3 }, sd::DataType::FLOAT32); + x.linspace(1); + exp.p(0, 2, true); + exp.p(1, 2, true); + exp.p(2, 2, true); + + sd::ops::ismax ismaxOp; + auto result = ismaxOp.evaluate({ &x }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto res = result.at(0); + //res->printIndexedBuffer("IS_MAX"); + ASSERT_TRUE(exp.equalsTo(res)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, IsMax2) { + NDArray x('c', { 3, 3 }, sd::DataType::FLOAT32); + // NDArray exp('c', {3, 3}, sd::DataType::BOOL); + NDArray exp('c', { 3, 3 }, sd::DataType::FLOAT32); 
+ x.linspace(1); + //exp.p(0, 2, true); + //exp.p(1, 2, true); + exp.p(2, 2, true); + + sd::ops::ismax ismaxOp; + auto result = ismaxOp.evaluate({ &x }, {}, { 0, 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto res = result.at(0); + //res->printIndexedBuffer("IS_MAX"); + ASSERT_TRUE(exp.equalsTo(res)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, IsMax3) { + NDArray x = NDArrayFactory::create(120.f); //('c', {3, 3}, sd::DataType::FLOAT32); +// NDArray exp('c', {3, 3}, sd::DataType::BOOL); + NDArray exp = NDArrayFactory::create(1.f);//, sd::DataType::FLOAT32); //'c', {3, 3}, sd::DataType::FLOAT32); + x.linspace(1); + //exp.p(0, 2, true); + //exp.p(1, 2, true); + //exp.p(2, 2, true); + + sd::ops::ismax ismaxOp; + auto result = ismaxOp.evaluate({ &x }, {}, { 0 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto res = result.at(0); + //res->printIndexedBuffer("IS_MAX"); + ASSERT_TRUE(exp.equalsTo(res)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, IsMax4) { + auto x = NDArrayFactory::create('c', { 6 }, { 0, 0, 0, 2, 2, 0 }); + auto z = NDArrayFactory::create('c', { 6 }); + auto e = NDArrayFactory::create('c', { 6 }, { false, false, false, true, false, false }); + + sd::ops::ismax op; + auto result = op.execute({ &x }, { &z }); + ASSERT_EQ(Status::OK(), result); + + ASSERT_EQ(e, z); +} + +//////////////////////////////////////////////////////////////////// +// TEST_F(DeclarableOpsTests1, sru_old_test1) { + +// const int bS = 2; +// const int K = 3; +// const int N = 4; + +// NDArray input('c', {bS,K,N}, sd::DataType::DOUBLE); +// NDArray weights('c', {3*K,K}, sd::DataType::DOUBLE); +// NDArray bias('c', {1,2*K}, sd::DataType::DOUBLE); +// NDArray init('c', {bS,K}, sd::DataType::DOUBLE); +// NDArray mask('c', {bS,K}, sd::DataType::DOUBLE); +// NDArray expState('c', {bS,K,N}, {0.847983, 0.874549, 0.896109, 0.913715, 
0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715}, sd::DataType::DOUBLE); +// NDArray expOut('c', {bS,K,N}, {1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656}, sd::DataType::DOUBLE); + +// input.assign(1.5); +// weights.assign(0.5); +// bias.assign(0.3) ; +// init.assign(1.); +// mask.assign(1.); + +// sd::ops::sru_old op; +// auto results = op.execute({&input, &weights, &bias, &init, &mask}, {}, {}); +// ASSERT_TRUE(results.size() == 2); + +// auto state = results.at(0); +// auto output = results.at(1); +// // state->printBuffer(); +// // expState.printIndexedBuffer("EXP STATE"); +// // state->printIndexedBuffer("OUT STATE"); +// ASSERT_TRUE(expState.equalsTo(state)); +// ASSERT_TRUE(expOut.equalsTo(output)); + +// +// } + +////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, sru_test1) { + + const int bS = 2; + const int K = 3; + const int N = 4; + + NDArray input('c', { bS,K,N }, sd::DataType::DOUBLE); + NDArray weights('c', { 3 * K,K }, sd::DataType::DOUBLE); + NDArray bias('c', { 2 * K }, sd::DataType::DOUBLE); + NDArray init('c', { bS,K }, sd::DataType::DOUBLE); + NDArray mask('c', { bS,K }, sd::DataType::DOUBLE); + NDArray expState('c', { bS,K,N }, { 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656, 1.090533, 1.174509, 1.252403, 1.324656 }, sd::DataType::DOUBLE); + NDArray expOut('c', { bS,K,N }, { 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 
0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715 }, sd::DataType::DOUBLE); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + init.assign(1.); + mask.assign(1.); + + sd::ops::sru op; + auto results = op.evaluate({ &input, &weights, &bias, &init, &mask }); + ASSERT_TRUE(results.size() == 2); + + auto output = results.at(0); + auto state = results.at(1); + + ASSERT_TRUE(expState.equalsTo(state)); + ASSERT_TRUE(expOut.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, sru_bp) { + + const int bS = 2; + const int K = 3; + const int N = 4; + std::vector expGradXBuff = { -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165, -0.0259303, -0.03869125, -0.0302272, -0.02299165 }; + std::vector expGradWBuff = { 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, 0.42526005,0.42526005,0.42526005, -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.5282811 , -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215, -0.15967215 }; + std::vector expGradBBuff = { -0.7043748, -0.7043748, -0.7043748, -0.2128962, -0.2128962, -0.2128962 }; + std::vector expGradInitBuff = { 1.1421, 1.1421, 1.1421, 1.1421, 1.1421, 1.1421 }; + std::vector stateBuff = { 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 
0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715, 0.847983, 0.874549, 0.896109, 0.913715 }; + + auto input = NDArrayFactory::create('c', { bS,K,N }); + auto weights = NDArrayFactory::create('c', { 3 * K,K }); + auto bias = NDArrayFactory::create('c', { 1,2 * K }); + auto init = NDArrayFactory::create('c', { bS,K }); + auto mask = NDArrayFactory::create('c', { bS,K }); + auto state = NDArrayFactory::create('c', { bS,K,N }, stateBuff); + auto inGradCt = NDArrayFactory::create('c', { bS,K }); + auto inGradH = NDArrayFactory::create('c', { bS,K,N }); + + auto expGradX = NDArrayFactory::create('c', { bS,K,N }, expGradXBuff); + auto expGradW = NDArrayFactory::create('c', { bS,3 * K,K }, expGradWBuff); + auto expGradB = NDArrayFactory::create('c', { 1,2 * K }, expGradBBuff); + auto expGradInit = NDArrayFactory::create('c', { bS,K }, expGradInitBuff); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + mask.assign(1.); + init.assign(1.); + inGradCt.assign(0.5); + inGradH.assign(0.5); + + sd::ops::sru_bp bp; + auto resultsBP = bp.evaluate({ &input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask }, {}, {}); + ASSERT_TRUE(resultsBP.size() == 4); + + auto gradX = resultsBP.at(0); + auto gradW = resultsBP.at(1); + auto gradB = resultsBP.at(2); + auto gradInit = resultsBP.at(3); + // expGradX.printBuffer("Exp GRAD"); + // gradX->printBuffer("Res GRAD"); + ASSERT_TRUE(expGradX.equalsTo(gradX, 1e-4)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + ASSERT_TRUE(expGradInit.equalsTo(gradInit)); + + +} + +////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, sru_bi_1) { + + const int bS = 2; + const int K = 3; + const int N = 4; + + NDArray input('c', { N,bS,2 * K }, sd::DataType::DOUBLE); + NDArray weights('c', { 2 * K,6 * K }, sd::DataType::DOUBLE); + NDArray bias('c', { 4 * K }, 
sd::DataType::DOUBLE); + NDArray init('c', { bS,2 * K }, sd::DataType::DOUBLE); + NDArray mask('c', { bS,2 * K }, sd::DataType::DOUBLE); + NDArray expState('c', { N,bS,2 * K }, { 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.08501, 1.08501, 1.08501, 1.0569, 1.0569, 1.0569, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857, 1.11288, 1.11288, 1.11288, 1.02857, 1.02857, 1.02857 }); + NDArray expOut('c', { N,bS,2 * K }, { 0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.790317, 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, 0.800804, 0.800804, 0.800804, 0.790317, 0.790317, 0.790317, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265, 0.810752, 0.810752, 0.810752, 0.779265, 0.779265, 0.779265 }); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + init.assign(1.); + mask.assign(1.); + + sd::ops::sru_bi op; + auto results = op.evaluate({ &input, &weights, &bias, &init, &mask }, {}, {}); + ASSERT_TRUE(results.size() == 2); + + auto output = results.at(0); + auto state = results.at(1); + // state->printBuffer(); + // output->printBuffer(); + + ASSERT_TRUE(expState.equalsTo(state)); + ASSERT_TRUE(expOut.equalsTo(output)); + + +} + +TEST_F(DeclarableOpsTests1, sru_bi_bp_1) { + + const int bS = 2; + const int K = 3; + const int N = 3; + std::vector expGradXBuff = { 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 
0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129, 0.00408129 }; + std::vector expGradInitBuff = { 1.05121, 1.05121, 1.05121, 1.02676, 1.02676, 1.02676, 1.05121, 1.05121, 1.05121, 1.02676, 1.02676, 1.02676 }; + std::vector expGradWBuff = { 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02595354,-0.090096 ,-0.00882456,0.02595354,-0.090096 ,-0.0088245, 0.02595354,-0.090096 ,-0.00882456,0.01651665,-0.0559437,-0.0084390, 0.01651665,-0.0559437,-0.00843906,0.01651665,-0.0559437,-0.00843906, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 
0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.02124567,-0.0731508,-0.00868926,0.02124567,-0.0731508,-0.0086892, 0.02124567,-0.0731508,-0.00868926,0.02084955,-0.0712011,-0.0085608, 0.02084955,-0.0712011,-0.00856086,0.02084955,-0.0712011,-0.00856086, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926, 0.01671156,-0.0570699,-0.00856086,0.01671156,-0.0570699,-0.0085608, 0.01671156,-0.0570699,-0.00856086,0.02534988,-0.0880002,-0.0086892, 
0.02534988,-0.0880002,-0.00868926,0.02534988,-0.0880002,-0.00868926 }; + std::vector expGradBBuff = { -0.0734389, -0.0734389, -0.0734389, -0.0717151, -0.0717151, -0.0717151, -0.0734389, -0.0734389, -0.0734389, -0.0717151, -0.0717151, -0.0717151, -0.00869156, -0.00869156, -0.00869156, -0.00856306, -0.00856306, -0.00856306, -0.00869156, -0.00869156, -0.00869156, -0.00856306, -0.00856306, -0.00856306 }; + std::vector stateBuff = { 1.028569, 1.028569, 1.028569, 1.112884, 1.112884, 1.112884, 1.028569, 1.028569, 1.028569, 1.112884,1.112884, 1.112884, 1.056905, 1.056905, 1.056905, 1.085009, 1.085009, 1.085009, 1.056905, 1.056905,1.056905, 1.085009, 1.085009, 1.085009, 1.085009, 1.085009, 1.085009, 1.056905, 1.056905, 1.056905,1.085009, 1.085009, 1.085009, 1.056905, 1.056905, 1.056905 }; + + auto input = NDArrayFactory::create('c', { N,bS,2 * K }); + auto weights = NDArrayFactory::create('c', { 2 * K,6 * K }); + auto bias = NDArrayFactory::create('c', { 4 * K }); + auto init = NDArrayFactory::create('c', { bS,2 * K }); + auto mask = NDArrayFactory::create('c', { bS,2 * K }); + NDArray state('c', { N,bS,2 * K }, stateBuff); + auto inGradCt = NDArrayFactory::create('c', { bS,2 * K }); + auto inGradH = NDArrayFactory::create('c', { N,bS,2 * K }); + + NDArray gradBias('c', { bS,4 * K }, expGradBBuff); + + NDArray expGradX('c', { N,bS,2 * K }, expGradXBuff); + NDArray expGradW('c', { N,2 * K,6 * K }, expGradWBuff); + auto expGradB = NDArrayFactory::create('c', { 4 * K }); + gradBias.reduceAlongDimension(reduce::Sum, expGradB, { 0 }); // [bS, 4K] -> [4K] + NDArray expGradInit('c', { bS,2 * K }, expGradInitBuff); + + input.assign(1.5); + weights.assign(0.5); + bias.assign(0.3); + mask.assign(1.); + init.assign(1.); + inGradCt.assign(0.5); + inGradH.assign(0.5); + + sd::ops::sru_bi_bp bp; + auto resultsBP = bp.evaluate({ &input, &weights, &bias, &init, &state, &inGradCt, &inGradH, &mask }, {}, {}); + ASSERT_TRUE(resultsBP.size() == 4); + + auto gradX = resultsBP.at(0); + auto 
gradW = resultsBP.at(1); + auto gradB = resultsBP.at(2); + auto gradInit = resultsBP.at(3); + + ASSERT_TRUE(expGradX.equalsTo(gradX)); + ASSERT_TRUE(expGradW.equalsTo(gradW)); + ASSERT_TRUE(expGradB.equalsTo(gradB)); + ASSERT_TRUE(expGradInit.equalsTo(gradInit)); + + +} + +TEST_F(DeclarableOpsTests1, ArgMax1) { + auto x = NDArrayFactory::create('c', { 3, 5 }); + x.linspace(1); + auto exp = NDArrayFactory::create('c', { 3 }); + exp.assign(4); + + sd::ops::argmax op; + + auto result = op.evaluate({ &x }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests1, ArgMax2) { + auto x = NDArrayFactory::create('c', { 3, 5 }); + x.linspace(1); + auto exp = NDArrayFactory::create('c', { 5 }); + exp.assign(2); + + sd::ops::argmax op; + + auto result = op.evaluate({ &x }, {}, { 0 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests1, ArgMax3) { + auto x = NDArrayFactory::create('c', { 3, 5 }); + auto dim = NDArrayFactory::create('c', { 1, 1 }, { 0. 
}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', { 5 }); + exp.assign(2); + + sd::ops::argmax op; + + auto result = op.evaluate({ &x, &dim }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests1, ArgMax4) { + auto x = NDArrayFactory::create('c', { 3, 5 }); + auto dim = NDArrayFactory::create('c', { 1, 1 }, { 1 }); + x.linspace(1); + auto exp = NDArrayFactory::create('c', { 3 }); + exp.assign(4); + + sd::ops::argmax op; + + auto result = op.evaluate({ &x, &dim }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests1, ArgMax5) { + auto x = NDArrayFactory::create('c', { 3, 5 }); + auto dim = NDArrayFactory::create('c', { 1, 2 }, { 0, 1 }); + x.linspace(1); + auto exp = NDArrayFactory::create(14); + + + sd::ops::argmax op; + + auto result = op.evaluate({ &x, &dim }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests1, ArgMax6) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto dim = NDArrayFactory::create(-1.f); + x.linspace(1); + + + sd::ops::argmax op; + + auto expected = op.evaluate({ &x }, {}, { 2 }); + ASSERT_EQ(Status::OK(), expected.status()); + auto exp = expected.at(0); + + + auto result = op.evaluate({ &x, &dim }, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(*exp, *z); +} + + +TEST_F(DeclarableOpsTests1, ArgMin1) { + auto x = NDArrayFactory::create('c', { 3, 5 }); + x.linspace(1); + // auto exp('c', {3, 1}); + auto exp = NDArrayFactory::create('c', { 3 }); + exp.assign(0.0f); + + sd::ops::argmin op; + + auto result = op.evaluate({ &x }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, 
result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests1, SquareTests1) { + auto x = NDArrayFactory::create('c', { 3, 5 }); + x.linspace(1); + + auto exp = NDArrayFactory::create('c', { 3, 5 }); + exp.linspace(1); + exp *= exp; + + sd::ops::square op; + + auto result = op.evaluate({ &x }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests1, OneHotTests_1) { + + auto indices = NDArrayFactory::create('c', { 1, 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); + + auto exp = NDArrayFactory::create('c', { 1, 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); + + sd::ops::onehot op; + + auto result = op.evaluate({ &indices }, { 1.0f, 0.0f }, { -1, 3 }); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests1, OneHotTests_2) { + auto indices = NDArrayFactory::create('c', { 2, 2 }, { 0.f, 2.f, 1.f, -1.f }); + + auto exp = NDArrayFactory::create('c', { 2, 2, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f }); + + sd::ops::onehot op; + auto result = op.evaluate({ &indices }, { 1.0f, 0.0f }, { -1, 3 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests1, OneHotTests_3) { + auto indices = NDArrayFactory::create('c', { 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); + + auto exp = NDArrayFactory::create('c', { 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); + + sd::ops::onehot op; + + auto result = op.evaluate({ &indices }, { 1.0f, 0.0f }, { -1, 3 }); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + // z->printIndexedBuffer("z"); + + 
ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests1, OneHotTests_4) { + auto indices = NDArrayFactory::create('c', { 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); + auto depth = NDArrayFactory::create(3.0f); + + auto exp = NDArrayFactory::create('c', { 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); + + sd::ops::onehot op; + + auto result = op.evaluate({ &indices, &depth }, { 1.0f, 0.0f }, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests1, OneHotTests_5) { + auto indices = NDArrayFactory::create('c', { 4 }, { 0.0f, 2.0f, -1.0f, 1.0f }); + auto depth = NDArrayFactory::create(3.0f); + auto on = NDArrayFactory::create(1.0f); + auto off = NDArrayFactory::create(0.0f); + + auto exp = NDArrayFactory::create('c', { 4, 3 }, { 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f }); + + sd::ops::onehot op; + + auto result = op.evaluate({ &indices, &depth, &on, &off }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests1, OneHotTests_6) { + auto indices = NDArrayFactory::create('c', { 3 }, { 0.f, 1.f, 2.f }); + auto e = NDArrayFactory::create('c', { 3, 3 }, { 1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f }); + + sd::ops::onehot op; + auto result = op.evaluate({ &indices }, { 1.0, 0.0 }, { 0, 3 }); + auto z = result.at(0); + + ASSERT_EQ(e, *z); + + +} + +TEST_F(DeclarableOpsTests1, OneHotTests_7) { + auto indices = NDArrayFactory::create('c', { 3 }, { 0, 1, 2 }); + auto e = NDArrayFactory::create('c', { 3, 3 }, { 1., 0., 0., 0., 1., 0., 0., 0., 1. 
}); + + sd::ops::onehot op; + auto result = op.evaluate({ &indices }, { 1.0, 0.0 }, { 0, 3 }, {}, { sd::DataType::HALF }, false); + auto z = result.at(0); + + ASSERT_EQ(e, *z); + + +} + +TEST_F(DeclarableOpsTests1, FillAs_1) { + auto x = NDArrayFactory::create('c', { 2, 2 }); + x.assign(117); + + float scalar = 119.f; + + sd::ops::fill_as op; + auto result = op.evaluate({ &x }, { scalar }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(x.isSameShape(result.at(0))); + + ASSERT_NEAR(scalar, result.at(0)->meanNumber().e(0), 1e-5f); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, LRN1) { + sd::ops::lrn lrn; + + lrn.getOpName(); +} + +TEST_F(DeclarableOpsTests1, Test_Range_Integer_1) { + auto exp = NDArrayFactory::create('c', { 4 }); + exp.linspace(1); + + sd::ops::range op; + + auto result = op.evaluate({}, {}, { 1, 5, 1 }); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(1, result.size()); + + auto array = result.at(0); + // array->printIndexedBuffer("Range integer 1"); + ASSERT_TRUE(exp.isSameShape(array)); + ASSERT_TRUE(exp.equalsTo(array)); + + +} + + +TEST_F(DeclarableOpsTests1, Test_Range_Integer_2) { + auto exp = NDArrayFactory::create('c', { 4 }); + exp.linspace(1); + + auto start = NDArrayFactory::create('c', { 1, 1 }); + auto stop = NDArrayFactory::create('c', { 1, 1 }); + auto step = NDArrayFactory::create('c', { 1, 1 }); + start.p(0, 1.f); + stop.p(0, 5.f); + step.p(0, 1.f); + + sd::ops::range op; + + auto result = op.evaluate({ &start, &stop, &step }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(1, result.size()); + + auto array = result.at(0); + + ASSERT_TRUE(exp.isSameShape(array)); + ASSERT_TRUE(exp.equalsTo(array)); + + +} + + +TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) { + auto exp = NDArrayFactory::create('c', { 4 }); + exp.linspace(1); + + sd::ops::range op; + + auto result = op.evaluate({}, { 1.f, 5.f, 1.f }, {}); + 
ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(1, result.size()); + + auto array = result.at(0); + + ASSERT_TRUE(exp.isSameShape(array)); + ASSERT_TRUE(exp.equalsTo(array)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test1) { + + NDArray input('c', { 3, 3 }, { -1.f, 1.f, -2.f, 2.f, -3.f, 3.f, -4.f, 4.f, 5.f }, sd::DataType::FLOAT32); + + NDArray expOutput('c', { 3, 3 }, { 1.14195199e-01, 8.43794734e-01, 4.20100661e-02, 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, 9.02116571e-05, 2.68917160e-01, 7.30992629e-01 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, {}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test2) { + NDArray input('c', { 3, 3, 3 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14 }, sd::DataType::FLOAT32); + NDArray expOutput('c', { 3, 3, 3 }, { 4.73142e-02,4.73847e-02,6.69062e-03, 9.50330e-01,8.67881e-04,9.92976e-01, 2.35563e-03,9.51747e-01,3.33106e-04, 4.74259e-02,2.26032e-06,4.74259e-02, 2.91395e-07,9.99998e-01,3.94360e-08, 9.52574e-01,1.12535e-07,9.52574e-01, 7.58256e-10,4.74259e-02,1.22325e-11, 1.00000e+00,1.32293e-11,1.19203e-01, 3.77513e-11,9.52574e-01,8.80797e-01 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, { 1 }, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test3) { + NDArray input('c', { 3, 3, 3 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, 
-10,10, -11,11, -12,12, -13,13, 14 }, sd::DataType::FLOAT32); + NDArray expOutput('c', { 3, 3, 3 }, { 2.47262e-03,1.23395e-04,3.35350e-04, 1.23395e-04,4.53979e-05,1.23395e-04, 6.14417e-06,1.23395e-04,5.56530e-09, 9.97527e-01,1.12521e-07,9.99665e-01, 1.52281e-08,9.99955e-01,2.06090e-09, 9.99994e-01,2.78912e-10,6.69285e-03, 3.05146e-07,9.99876e-01,4.13855e-08, 9.99877e-01,5.60254e-09,9.99877e-01, 7.58251e-10,9.99877e-01,9.93307e-01 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, { 0 }, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test4) { + NDArray input('c', { 1, 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); + NDArray expOutput('c', { 1, 5 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, { 1 }, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test5) { + NDArray input('c', { 1, 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); + NDArray expOutput('c', { 1, 5 }, { 1,1,1,1,1 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, { 0 }); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test6) { + NDArray input('c', { 5, 1 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); + NDArray expOutput('c', { 5, 1 }, { 
0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, { 0 }, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test7) { + NDArray input('c', { 5, 1 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); + NDArray expOutput('c', { 5, 1 }, { 1,1,1,1,1 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, { 1 }, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test8) { + NDArray input('c', { 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); + NDArray expOutput('c', { 5 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, {}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test9) { + NDArray input('c', { 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8 }, sd::DataType::FLOAT32); + NDArray expOutput('c', { 2, 2, 2, 2 }, { 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, { 2 }, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + 
ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test10) { + NDArray input('c', { 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16 }, sd::DataType::FLOAT32); + NDArray expOutput('c', { 2, 2, 2, 2, 2 }, { 0.119203, 0.880797, 0.017986, 0.982014, 0.002473, 0.997527, 0.000335, 0.999665, 0.000045, 0.999955, 0.000006, 0.999994, 0.000001, 0.999999, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.00000 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, { 4 }, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test11) { + NDArray input('c', { 2, 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16, -2.1, 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5,2.5 ,-2.6,2.6, -2.7,2.7, -2.8,2.8, -2.9,2.9, -3.0,3.0, -3.1,3.1, -3.2,3.2, -3.3,3.3, 3.4, -3.4, 3.5, -3.5, 3.6,-3.6 }, sd::DataType::FLOAT32); + NDArray expOutput('c', { 2, 2, 2, 2, 2, 2 }, { 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.000000, 1.000000, 1.000000, 0.000000, 0.268941, 0.731059, 0.731059, 0.268941, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 
0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.001229, 0.998771, 0.998771, 0.001229, 0.475021, 0.524979, 0.524979, 0.475021 }, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, { 4 }, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, softmax_test12) { + NDArray input('f', { 2, 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16, -2.1, 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5,2.5 ,-2.6,2.6, -2.7,2.7, -2.8,2.8, -2.9,2.9, -3.0,3.0, -3.1,3.1, -3.2,3.2, -3.3,3.3, 3.4, -3.4, 3.5, -3.5, 3.6,-3.6 }, sd::DataType::FLOAT32); + NDArray exp('c', { 2, 2, 2, 2, 2, 2 }, { 0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, 0.017986, 0.401312, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, 0.401312, 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001659, 0.017986, 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001113, 0.017986, 0.401312, 1.000000, 0.998887, 0.017986, 0.401312, 0.017986, 0.401312, 0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001659, 0.017986, 0.401312, 1.000000, 0.998887, 0.982014, 0.598688, 0.000000, 0.001113 }, sd::DataType::FLOAT32); + + auto expOutput = NDArray('f', { 2, 2, 2, 2, 2, 2 }, sd::DataType::FLOAT32); + expOutput.assign(exp); + + sd::ops::softmax op; + auto results = op.evaluate({ &input }, {}, { 3 }, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_1) { + + float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; + float expBuff[] = { 24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13., 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1. }; + Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); + + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { 0,1,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_2) { + + float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; + float expBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; + Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); + + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, {}, {}, {}, true); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(input)); + ASSERT_TRUE(expected.equalsTo(&input)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_3) { + + float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; + float expBuff[] = { 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1., 24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13. 
}; + Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); + + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { 1,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_4) { + + float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; + float expBuff[] = { 16,15,14,13,20,19,18,17,24,23,22,21,4,3,2,1,8,7,6,5,12,11,10,9, }; + Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); + + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { 0,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_5) { + + float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; + float expBuff[] = { 21., 22., 23., 24., 17., 18., 19., 20., 13., 14., 15., 16., 9., 10., 11., 12., 5., 6., 7., 8., 1., 2., 3., 4. 
}; + Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); + + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { 0,1 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_6) { + + float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; + float expBuff[] = { 4., 3., 2., 1., 8., 7., 6., 5., 12., 11., 10., 9., 16., 15., 14., 13., 20., 19., 18., 17., 24., 23., 22., 21. }; + Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); + + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { 2 }, {}, {}, true); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(input)); + ASSERT_TRUE(expected.equalsTo(&input)); + + +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_7) { + + float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; + float expBuff[] = { 9., 10., 11., 12., 5., 6., 7., 8., 1., 2., 3., 4., 21., 22., 23., 24., 17., 18., 19., 20., 13., 14., 15., 16. 
}; + Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); + + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + //expected.printIndexedBuffer("E"); + //result->printIndexedBuffer("R"); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_8) { + + float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; + float expBuff[] = { 12., 11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1., 24., 23., 22., 21., 20., 19., 18., 17., 16., 15., 14., 13. }; + Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); + + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { 2,1 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_9) { + + float inBuff[] = { 1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24 }; + float expBuff[] = { 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. 
}; + Nd4jLong shapeInfo[] = { 3, 2, 3, 4, 12, 4, 1, 0, 1, 99 }; + ArrayOptions::setDataType(shapeInfo, sd::DataType::FLOAT32); + + NDArray input(inBuff, shapeInfo); + NDArray expected(expBuff, shapeInfo); + NDArray output(shapeInfo); + + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { 0 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +TEST_F(DeclarableOpsTests1, Reverse_10) { + auto x = NDArrayFactory::create('c', { 4, 3 }, { 1.5375735, 0.1592365, 0.09966054, 0.677872, 1.144433, -1.0355669, 0.48456487, -0.67863184, 0.85020787, 0.13950661, 0.20998026, -1.1660044 }); + auto i = NDArrayFactory::create('c', { 1 }, { -1 }); + auto e = NDArrayFactory::create('c', { 4, 3 }, { 0.09966054, 0.1592365, 1.5375735, -1.0355669, 1.144433, 0.677872,0.85020787, -0.67863184, 0.48456487, -1.1660044, 0.20998026, 0.13950661 }); + + sd::ops::reverse op; + auto result = op.evaluate({ &x, &i }, {}, {}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_11) { + + + auto input = NDArrayFactory::create('c', { 2,3,4 }); + auto expected = NDArrayFactory::create('c', { 2,3,4 }, { 24.f, 23.f, 22.f, 21.f, 20.f, 19.f, 18.f, 17.f, 16.f, + 15.f, 14.f, 13.f, 12.f, 11.f, 10.f, 9.f, 8.f, 7.f, + 6.f, 5.f, 4.f, 3.f, 2.f, 1.f }); + + input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { 0, 1, 2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_12) { + + + auto input = NDArrayFactory::create({ 0.f, 1.f, 
2.f, 3.f, 4.f }); + auto expected = NDArrayFactory::create({ 4.f, 3.f, 2.f, 1.f, 0.f }); + + //input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { 0 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + //result->printIndexedBuffer("Result reverse"); + //expected.printIndexedBuffer("Expected reverse"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_13) { + + + auto input = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); + auto expected = NDArrayFactory::create({ 4.f, 3.f, 2.f, 1.f, 0.f }); + + //input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, { -1 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests1, Reverse_14) { + + + auto input = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); + auto expected = NDArrayFactory::create({ 0.f, 1.f, 2.f, 3.f, 4.f }); + + //input.linspace(1); + sd::ops::reverse op; + auto results = op.evaluate({ &input }, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +TEST_F(DeclarableOpsTests1, Test_Expose_1) { + auto input0 = NDArrayFactory::create('c', { 2, 3 }, { 1, 2, 3, 6, 5, 4 }); + auto input1 = NDArrayFactory::create('c', { 2, 3 }, { 3, 2, 1, 4, 5, 6 }); + + sd::ops::expose op; + + auto result = op.evaluate({ &input0, &input1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z0 = result.at(0); + auto z1 = result.at(1); + + ASSERT_TRUE(input0.equalsTo(z0)); + 
ASSERT_TRUE(input1.equalsTo(z1)); + + +} + +TEST_F(DeclarableOpsTests1, Test_Expose_2) { + auto list = new NDArrayList(0, true); + + auto var = new Variable(nullptr, "arraylist", -1, 0); + var->setNDArrayList(list); + + VariableSpace variableSpace; + variableSpace.putVariable(-1, var); + variableSpace.trackList(list); + + Context block(1, &variableSpace); + block.pickInput(-1); + + sd::ops::expose op; + auto result = op.execute(&block); + + ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_TRUE(variableSpace.hasVariable(1)); + + auto var1 = variableSpace.getVariable(1); + + ASSERT_EQ(var->variableType(), var1->variableType()); + + auto list1 = var1->getNDArrayList(); + + ASSERT_TRUE(list == list1); + +} + +TEST_F(DeclarableOpsTests1, Test_Release) { + auto x = NDArrayFactory::create('c', { 8, 8 }); + // x.printShapeInfo("x shape"); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests10.cpp new file mode 100644 index 000000000..01403e968 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests10.cpp @@ -0,0 +1,3238 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + + +// +// Created by raver on 8/4/2018. +// + +#include "testlayers.h" +#include +#include +#include +#include + + +using namespace sd; + + +class DeclarableOpsTests10 : public testing::Test { +public: + + DeclarableOpsTests10() { + printf("\n"); + fflush(stdout); + } +}; + +template +class TypedDeclarableOpsTests10 : public testing::Test { +public: + + TypedDeclarableOpsTests10() { + printf("\n"); + fflush(stdout); + } +}; + +typedef ::testing::Types TestingTypes; +TYPED_TEST_CASE(TypedDeclarableOpsTests10, TestingTypes); + +TEST_F(DeclarableOpsTests10, Test_ArgMax_1) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto e = NDArrayFactory::create(8); + + x.linspace(1.0); + + + sd::ops::argmax op; + auto result = op.evaluate({&x}); + ASSERT_EQ(Status::OK(), result.status()); + + + auto z = *result.at(0); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests10, Test_ArgMax_2) { + auto x = NDArrayFactory::create('c', {3, 3}); + auto y = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {3}, {2, 2, 2}); + + x.linspace(1.0); + + sd::ops::argmax op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = *result.at(0); + + //z.printIndexedBuffer("z"); + //z.printShapeInfo("z shape"); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests10, Test_And_1) { + auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); + auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + auto e = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + + sd::ops::boolean_and op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(e, *result.at(0)); +} + +TEST_F(DeclarableOpsTests10, Test_Or_1) { + auto x = NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); + auto y = NDArrayFactory::create('c', {4}, {0, 0, 0, 1}); + auto e = 
NDArrayFactory::create('c', {4}, {1, 1, 0, 1}); + + sd::ops::boolean_or op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(e, *result.at(0)); +} + +TEST_F(DeclarableOpsTests10, Test_Not_1) { + auto x = NDArrayFactory::create('c', {4}, {true, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, false, true}); +// auto e = NDArrayFactory::create('c', {4}, {1, 1, 1, 0}); + auto e = NDArrayFactory::create('c', {4}, {false, false, true, false}); + + sd::ops::boolean_not op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + + ASSERT_TRUE(e.equalsTo(res)); +} + +TEST_F(DeclarableOpsTests10, Test_Size_at_1) { + auto x = NDArrayFactory::create('c', {10, 20, 30}); + auto e = NDArrayFactory::create(20); + + sd::ops::size_at op; + auto result = op.evaluate({&x}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(e, *result.at(0)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, MirrorPad_SGO_Test_1) { + + auto in = NDArrayFactory::create({1., 2., 3., 4., 5.}); +// auto pad('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); +// auto value(10.0); + + auto exp = NDArrayFactory::create({2., 1., 2., 3., 4., 5., 4.}); + + sd::ops::mirror_pad op; + + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + + ASSERT_TRUE(exp.equalsTo(res.at(0))); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Unique_SGO_Test_1) { + auto input = NDArrayFactory::create({3., 4., 3., 1., 3., 0., 2., 4., 2., 4.}); + auto expIdx = NDArrayFactory::create({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); + auto exp = NDArrayFactory::create({3., 4., 1., 0., 2.}); + + sd::ops::unique op; + auto res = 
op.evaluate({&input}, {}, {}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto res1 = res.at(0); + auto res2 = res.at(1); + + ASSERT_TRUE(exp.equalsTo(res1)); + ASSERT_TRUE(expIdx.equalsTo(res2)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Where_SGO_Test_1) { + auto input = NDArrayFactory::create('c', {3, 3}, {true, false, false, true, true, false, true, true, true}); + //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {6, 2}, {0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 2LL, 0LL, 2LL, 1LL, 2LL, 2LL}); + + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + + ASSERT_TRUE(exp.isSameShape(resA)); + ASSERT_TRUE(exp.equalsTo(resA)); +// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Where_SGO_Test_02) { + auto input = NDArrayFactory::create('c', {2, 2, 2}, {true, false, false, true, true, true, true, false}); + //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {5, 3}, {0LL, 0LL, 0LL, 0LL, 1LL, 1LL, 1LL, 0LL, 0LL, 1LL, 0LL, 1LL, 1LL, 1LL, 0LL}); + + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); +// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_1) { + auto cond3d = NDArrayFactory::create('c', {2, 2, 2}, {true, false, false, true, true, true, true, false}); +// auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp1 = NDArrayFactory::create({0, 0, 1, 1, 1}); + auto exp2 = NDArrayFactory::create({0, 1, 0, 0, 1}); + auto 
exp3 = NDArrayFactory::create({0, 1, 0, 1, 0}); + sd::ops::where_np op; + auto res = op.evaluate({&cond3d}, {}, {}); + ASSERT_TRUE(res.size() == 3); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto res1 = res.at(0); + auto res2 = res.at(1); + auto res3 = res.at(2); +// res1->printShapeInfo("Res1 shape"); res1->printBuffer("Res1"); +// res2->printShapeInfo("Res2 shape"); res2->printBuffer("Res2"); +// res3->printShapeInfo("Res3 shape"); res3->printBuffer("Res3"); + ASSERT_TRUE(exp1.equalsTo(res1)); + ASSERT_TRUE(exp2.equalsTo(res2)); + ASSERT_TRUE(exp3.equalsTo(res3)); + //ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_2) { + auto cond2d = NDArrayFactory::create('c', {3, 5}, {true, true, false, false, true, true, true, + true, true, true, false, true, true, true, true}); +// auto expIdx({0, 1, 0, 2, 0, 3, 4, 1, 4, 1}); + auto exp1 = NDArrayFactory::create({0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2}); + auto exp2 = NDArrayFactory::create({0, 1, 4, 0, 1, 2, 3, 4, 1, 2, 3, 4}); + sd::ops::where_np op; + auto res = op.evaluate({&cond2d}, {}, {}); + ASSERT_TRUE(res.size() == 2); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + ASSERT_TRUE(exp1.equalsTo(res.at(0))); + ASSERT_TRUE(exp2.equalsTo(res.at(1))); + //ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Where_SGO_Test_2) { + auto input = NDArrayFactory::create({true, false, true, true, true}); + //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {4,1}, {0, 2, 3, 4}); + + sd::ops::Where op; + auto res = op.evaluate({&input}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); +// resA->printIndexedBuffer("Result A"); +// resA->printShapeInfo("ShapeA"); + ASSERT_TRUE(exp.equalsTo(resA)); + 
ASSERT_TRUE(exp.isSameShape(resA)); +// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Where_SGO_Test_3) { + auto input = NDArrayFactory::create('c', {5, 1}, {true, false, true, true, true}); + //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); + + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + //resA->printIndexedBuffer("Result A"); + //resA->printShapeInfo("ShapeA"); + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); +// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Where_SGO_Test_4) { + auto input = NDArrayFactory::create('c', {5, 1}, {false, false, false, false, false}); + //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); + + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(resA->isEmpty()); + //resA->printIndexedBuffer("Result A"); + //resA->printShapeInfo("ShapeA"); + //ASSERT_TRUE(exp.equalsTo(resA)); + //ASSERT_TRUE(exp.isSameShape(resA)); +// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Where_SGO_Test_5) { + auto input = NDArrayFactory::create('c', {5}, {1, 0, 0, 2, 3}); + //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {3, 1}, {0, 3, 4}); + + sd::ops::Where op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + 
//ASSERT_TRUE(resA->isEmpty()); + + ASSERT_TRUE(exp.equalsTo(resA)); + ASSERT_TRUE(exp.isSameShape(resA)); +// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, WhereNP_SGO_Test_4) { + auto input = NDArrayFactory::create('c', {5, 1}, {false, false, false, false, false}); + //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto exp = NDArrayFactory::create('c', {4, 2}, {0, 0, 2, 0, 3, 0, 4, 0}); + + sd::ops::where_np op; + auto res = op.evaluate({&input}, {}, {}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + ASSERT_TRUE(resA->isEmpty()); + //resA->printIndexedBuffer("Result A"); + //resA->printShapeInfo("ShapeA"); + //ASSERT_TRUE(exp.equalsTo(resA)); + //ASSERT_TRUE(exp.isSameShape(resA)); +// ASSERT_TRUE(expIdx.equalsTo(res.at(1))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_1) { + auto labels = NDArrayFactory::create('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); + //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto predictions = NDArrayFactory::create('c', {2, 3}, {-0.3, -0.2, -0.1, 0, 0.1, 0.2}); + auto weights = NDArrayFactory::create('c', {2, 1}, {0., 1.}); + auto exp = NDArrayFactory::create(0.6); + + sd::ops::cosine_distance_loss op; + auto res = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + + ASSERT_TRUE(exp.equalsTo(resA)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, CosineDistance_SGO_Test_2) { + auto labels = NDArrayFactory::create('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); + //auto expIdx({0., 1., 0., 2., 0., 3., 4., 1., 4., 1.}); + auto predictions = NDArrayFactory::create('c', {2, 3}, {-0.3, -0.2, -0.1, 0, 0.1, 0.2}); + auto weights 
= NDArrayFactory::create('c', {2, 1}, {0., 1.}); + auto exp = NDArrayFactory::create(0.6); + + sd::ops::cosine_distance_loss op; + auto res = op.evaluate({&predictions, &weights, &labels}, {}, {2, 1}); + ASSERT_TRUE(res.status() == ND4J_STATUS_OK); + auto resA = res.at(0); + + ASSERT_TRUE(exp.equalsTo(resA)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 3}); + + auto exp = NDArrayFactory::create('c', {2, 3, 3}); + x.linspace(1); + exp.linspace(1); + exp.p(0, 0, 2, 0.); + exp.p(1, 0, 2, 0.); + exp.p(0, 2, 0, 0.); + exp.p(1, 2, 0, 0.); + + sd::ops::matrix_band_part op; + auto results = op.evaluate({&x}, {}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + //results.at(0)->printIndexedBuffer("MBP Test1"); + //exp.printIndexedBuffer("MBP Expec"); + ASSERT_TRUE(exp.equalsTo(results.at(0))); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, TestMarixBandPart_Test_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 3}); + auto minD = NDArrayFactory::create(1); + auto maxD = NDArrayFactory::create(1); + auto exp = NDArrayFactory::create('c', {2, 3, 3}); + x.linspace(1); + exp.linspace(1); + exp.p(0, 0, 2, 0.); + exp.p(1, 0, 2, 0.); + exp.p(0, 2, 0, 0.); + exp.p(1, 2, 0, 0.); + + sd::ops::matrix_band_part op; + auto results = op.evaluate({&x, &minD, &maxD}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + //results.at(0)->printIndexedBuffer("MBP Test1"); + //exp.printIndexedBuffer("MBP Expec"); + ASSERT_TRUE(exp.equalsTo(results.at(0))); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, atan2_test1) { + + auto y = NDArrayFactory::create('c', {2, 3, 4}, {-1.001 ,-0.915 ,-0.829 ,-0.743 ,-0.657 ,-0.571 ,-0.485 ,-0.399 ,-0.313 ,-0.227 ,-0.141 ,-0.055 ,0.031 ,0.117 ,0.203 ,0.289 ,0.375 ,0.461 ,0.547 
,0.633 ,0.719 ,0.805 ,0.891 ,0.977}); + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-0.51, -0.46, -0.41, -0.36, -0.31, -0.26, -0.21, -0.16, -0.11, -0.06, -0.01, 0.04, 0.09, 0.14, 0.19, 0.24, 0.29, 0.34, 0.39, 0.44, 0.49, 0.54, 0.59, 0.61}); + + auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.04201, -2.03663, -2.03009, -2.02199,-2.01166, -1.99808, -1.97941, -1.95217,-1.90875, -1.8292 , -1.6416 , -0.942 , + 0.33172, 0.69614, 0.81846, 0.87776, 0.91253, 0.93533, 0.95141, 0.96336, 0.97259, 0.97993, 0.98591, 1.01266,}); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, atan2_test2) { + + auto y = NDArrayFactory::create('c', {2, 3, 4}, {-1.001 ,-0.915 ,-0.829 ,-0.743 ,-0.657 ,-0.571 ,-0.485 ,-0.399 ,-0.313 ,-0.227 ,-0.141 ,-0.055 ,0.031 ,0.117 ,0.203 ,0.289 ,0.375 ,0.461 ,0.547 ,0.633 ,0.719 ,0.805 ,0.891 ,0.977}); + auto x = NDArrayFactory::create('c', { 3, 4}, {-1.05, -0.82, -0.639, -0.458, -0.277, -0.096, 0.085, 0.266, 0.447, 0.628, 0.809, 0.99}); + + auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.38008, -2.30149, -2.22748, -2.1232 ,-1.96979, -1.73736, -1.3973 , -0.98279,-0.61088, -0.34685, -0.17256, -0.0555 , + 3.11208, 2.99987, 2.83399, 2.57869, 2.207 , 1.77611, 1.41664, 1.17298, 1.01458, 0.90829, 0.8336 , 0.77879}); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + // z->printIndexedBuffer(); + + // x.applyTrueBroadcast(sd::BroadcastOpsTuple::custom(scalar::Atan2, pairwise::Atan2, broadcast::Atan2), &y, &z, true); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests10, atan2_test3) { + + auto y = NDArrayFactory::create('c', {2, 3, 4}, {-1.001 ,-0.915 ,-0.829 ,-0.743 ,-0.657 ,-0.571 ,-0.485 ,-0.399 ,-0.313 ,-0.227 ,-0.141 ,-0.055 ,0.031 ,0.117 ,0.203 ,0.289 ,0.375 ,0.461 ,0.547 ,0.633 ,0.719 ,0.805 ,0.891 ,0.977}); + auto x = NDArrayFactory::create('c', { 3, 4}, {-1.05, -0.82, -0.639, -0.458, -0.277, -0.096, 0.085, 0.266, 0.447, 0.628, 0.809, 0.99}); + + auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.33231, -2.41089, -2.48491, -2.58919,-2.74259, -2.97502, 2.9681 , 2.55359, 2.18167, 1.91765, 1.74335, 1.62629, + -1.54128, -1.42907, -1.2632 , -1.00789,-0.63621, -0.20531, 0.15416, 0.39782, 0.55622, 0.6625 , 0.7372 , 0.79201}); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&x, &y}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, atan2_test4) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {-1.001 ,-0.829 ,-0.657 ,-0.485 ,-0.313 ,-0.141 ,0.031 ,0.203 ,0.375 ,0.547 ,0.719 ,0.891}); + auto x = NDArrayFactory::create('c', {2, 3, 1}, {-0.82, -0.458, -0.096, 0.085, 0.447, 0.809}); + + auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.45527, -2.36165, -2.24628, -2.10492,-2.1703 , -1.86945, -1.50321, -1.15359,-0.25062, -0.17373, -0.13273, -0.10733, + 3.05688, 3.03942, 3.01293, 2.9681 , 2.18167, 1.87635, 1.50156, 1.14451, 1.13674, 0.97626, 0.84423, 0.7372 }); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&x, &y}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, atan2_test5) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {-1.001 ,-0.829 ,-0.657 
,-0.485 ,-0.313 ,-0.141 ,0.031 ,0.203 ,0.375 ,0.547 ,0.719 ,0.891}); + auto x = NDArrayFactory::create('c', {2, 3, 1}, {-0.82, -0.458, -0.096, 0.085, 0.447, 0.809}); + + auto exp = NDArrayFactory::create('c', {2,3,4}, {-2.25712, -2.35074, -2.46611, -2.60747,-2.54209, -2.84294, 3.07401, 2.72438, 1.82141, 1.74453, 1.70353, 1.67813, + -1.48608, -1.46862, -1.44214, -1.3973 ,-0.61088, -0.30556, 0.06924, 0.42629, 0.43405, 0.59453, 0.72657, 0.8336 }); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, atan2_test6) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {-1.001 ,-0.829 ,-0.657 ,-0.485 ,-0.313 ,-0.141 ,0.031 ,0.203 ,0.375 ,0.547 ,0.719 ,0.891}); + auto x = NDArrayFactory::create('c', { 4}, {-0.82, -0.096, 0.085, 0.809}); + + auto exp = NDArrayFactory::create('c', {1,3,4}, {-2.25712, -1.68608, -1.44214, -0.54006,-2.77695, -2.16855, 0.34972, 0.24585, 2.71267, 1.74453, 1.45312, 0.8336 }); + + sd::ops::tf_atan2 op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, IGamma_Test1) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 ,7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); + auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); + + auto exp = NDArrayFactory::create('c', {1,3,4}, { + 0.659917, 0.61757898, 0.59726304, 0.58478117, + 0.0066205109, 0.022211598, 0.040677428, 0.059117373, + 0.0000039433403, 0.000086064574, 0.000436067, 0.0012273735}); + + sd::ops::igamma op; + auto result = 
op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); +// z->printBuffer("OUtput"); +// exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, IGamma_Test2) { + + auto y = NDArrayFactory::create('c', {1, 3, 4}, {1.1 , 2.1 , 3.1 ,4.1 , 5.1 , 6.1 , + 7.1 ,8.1 ,9.1 ,10.1,11.1 ,12.1}); + auto x = NDArrayFactory::create('c', { 4}, {1.2, 2.2, 3.2, 4.2}); + auto exp = NDArrayFactory::create('c', {1,3,4}, {0.340083, 0.382421, 0.402737, 0.415221, + 0.993379, 0.977788, 0.959323, 0.940883, + 0.999996, 0.999914, 0.999564, 0.998773}); + + sd::ops::igammac op; + auto result = op.evaluate({&y, &x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); +// z->printBuffer("OUtput"); +// exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, LGamma_Test1) { + + auto x = NDArrayFactory::create('c', {3, 3}, {0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.}); + + auto exp = NDArrayFactory::create('c', {3,3}, { + 2.2527127 , 0.5723649 , 0.26086727, + -0.12078223, -0.09580769, 0., + 0.28468287, 0.4348206 , 0.6931472 + }); + + sd::ops::lgamma op; + auto result = op.evaluate({&x}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); +// z->printBuffer("OUtput"); +// exp.printBuffer("EXpect"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, range_test10) { + + auto limit = NDArrayFactory::create('c', {1, 3, 4}); + limit = 5.; + auto exp = NDArrayFactory::create('c', {5}, {0.,1.,2.,3.,4.}); + + sd::ops::range op; + auto result = 
op.evaluate({&limit}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, range_test11) { + + auto limit = NDArrayFactory::create('c', {1, 3, 4}); + auto start = NDArrayFactory::create('c', {2, 4}); + limit = 5.; + start = 0.5; + auto exp = NDArrayFactory::create('c', {5}, {0.5,1.5,2.5,3.5,4.5}); + + sd::ops::range op; + auto result = op.evaluate({&start, &limit}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, range_test12) { + + auto exp = NDArrayFactory::create('c', {9}, {0.5f, 1.f , 1.5f, 2.f , 2.5f, 3.f , 3.5f, 4.f , 4.5f}); + + sd::ops::range op; + auto result = op.evaluate({}, {0.5, 5, 0.5}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, top_k_permuted_test1) { + + auto x = NDArrayFactory::create({7., 3., 1., 2., 5., 0., 4., 6., 9., 8.}); + auto expUnsorted = NDArrayFactory::create({7., 6., 9., 8.}); // Sorted = False + auto expSorted = NDArrayFactory::create({9., 8., 7., 6., 5.}); // Sorted = False + + + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {4}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + auto zI = result.at(1); + + ASSERT_TRUE(expUnsorted.isSameShape(z)); + ASSERT_TRUE(expUnsorted.equalsTo(z)); + + auto result2 = op.evaluate({&x}, {}, {5}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, result2.status()); + + z = result2.at(0); + zI = result2.at(1); + 
+ ASSERT_TRUE(expSorted.isSameShape(z)); + ASSERT_TRUE(expSorted.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, top_k_permuted_test2) { + + auto x = NDArrayFactory::create({7., 3., 1., 2., 5., 0., 4., 6., 9., 8.}); + auto expUnsorted = NDArrayFactory::create({7., 5., 6., 9., 8.}); // Sorted = False + auto expSorted = NDArrayFactory::create({9., 8., 7., 6., 5.}); // Sorted = False + + + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {5}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + auto zI = result.at(1); + + ASSERT_TRUE(expUnsorted.isSameShape(z)); + ASSERT_TRUE(expUnsorted.equalsTo(z)); + + auto result2 = op.evaluate({&x}, {}, {5}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, result2.status()); + + z = result2.at(0); + zI = result2.at(1); + + ASSERT_TRUE(expSorted.isSameShape(z)); + ASSERT_TRUE(expSorted.equalsTo(z)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test1) { + + auto labels = NDArrayFactory::create('c', {2,3},{3, 2, 1, 0, 1, 2}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {2,3}, {1.24254, 1.34254, 1.44254, 1.54254, 1.44254, 1.34254}); + + logits.linspace(0.1, 0.1); + + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test2) { + + auto labels = NDArrayFactory::create('c', {2},{1, 0}); + auto logits = NDArrayFactory::create('c', {2,3}); + auto expected = NDArrayFactory::create('c', {2}, 
{1.10194, 1.20194}); + + logits.linspace(0.1, 0.1); + + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test3) { + + NDArray labels('c', {1}, std::vector{0}, sd::DataType::INT32); + auto logits = NDArrayFactory::create('c', {1,3}); + auto expected = NDArrayFactory::create('c', {1}, {1.20194}); + + logits.linspace(0.1, 0.1); + + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, sparse_softmax_cross_entropy_loss_with_logits_test4) { + + auto labels = NDArrayFactory::create('c', {2},{0, 0}); + auto logits = NDArrayFactory::create('c', {2,1}); + auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); + + logits.linspace(0.1, 0.1); + + sd::ops::sparse_softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&labels, &logits}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, histogram_fixed_width_test1) { + + auto input = NDArrayFactory::create('c', {2,3},{-1.f, 0.f, 1.5f, 2.f, 5.f, 15.f}); + auto range = NDArrayFactory::create('c', {2}, {0, 5}); + auto exp = NDArrayFactory::create('c', {5}, {2, 1, 1, 0, 2}); + + 
sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, histogram_fixed_width_test2) { + + auto input = NDArrayFactory::create('c', {2,3,4},{0.f, 5.f, 2.f, 1.f, -1.f, 2.f, 5.f, 3.f, 2.f, 3.f, -1.f, 5.f, 3.f, 2.f, 1.f, 4.f, 2.f, 5.f, 5.f, 5.f, 6.f, 6.f, -1.f, 0.f}); + auto range = NDArrayFactory::create('c', {2}, {0, 5}); + auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 3, 9}); + + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, histogram_fixed_width_test3) { + + auto input = NDArrayFactory::create('c', {2,3,1,4,1},{0.f, 5.f, 2.001f, 1.f, -1.f, 2.f, 5.f, 3.f, 2.999f, 3.00001f, -1.f, 3.99999f, 3.f, 2.f, 1.f, 4.f, 2.f, 5.f, 5.f, 5.f, 6.f, 6.f, -1.f, 0.00001f}); + auto range = NDArrayFactory::create('c', {1,2,1}, {0, 5}); + auto exp = NDArrayFactory::create('c', {5}, {5, 2, 5, 4, 8}); + + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, histogram_fixed_width_test4) { + + auto input = NDArrayFactory::create('c', {20,5},{13.8387f,0.1509f,50.39f,30.403f,13.5174f,9.7351f,37.6652f,28.9215f,22.7011f,45.2834f,40.7628f,50.4995f,26.8003f,27.479f,44.633f,6.9109f,48.5004f, + 
46.5971f,1.6203f,23.6381f,38.9661f,50.8146f,17.2482f,8.0429f,7.5666f,7.9709f,21.8403f,20.1694f,23.3004f,50.9151f,46.239f,38.7323f,29.6946f,32.9876f, + 23.0013f,39.7318f,19.4486f,37.6147f,-0.1506f,5.3246f,3.6173f,24.2573f,4.3941f,9.7105f,24.0364f,35.3681f,17.7805f,35.7681f,16.4144f,17.4362f,8.4987f, + 26.8108f,36.2937f,31.6442f,29.7221f,8.7445f,33.3301f,4.0939f,13.078f,45.1481f,29.0172f,21.6548f,35.408f,27.1861f,2.2576f,40.6804f,36.2201f,29.7352f, + 29.1244f,38.7444f,5.8721f,33.5983f,48.2694f,34.4161f,19.7148f,13.8085f,13.6075f,22.5042f,37.8002f,50.0543f,48.5314f,20.3694f,28.5042f,-0.4679f,4.4245f, + 18.9837f,40.7724f,2.7611f,44.0431f,37.186f,27.7361f,14.6001f,9.1721f,14.6087f,21.4072f,49.3344f,11.4668f,14.6171f,15.2502f,5.244f}); + auto range = NDArrayFactory::create('c', {1,2}, {0, 50}); + auto exp = NDArrayFactory::create('c', {5}, {22, 17, 24, 19, 18}); + + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, histogram_fixed_width_test5) { + + auto input = NDArrayFactory::create('c', {5,20},{20.f, 0.f, 60.f, 40.f, 20.f, 0.f, 40.f, 0.f, 40.f, 40.f,40.f,60.f, 20.f, 20.f, 60.f, 0.f, 40.f, + 46.5971f,1.6203f,23.6381f,38.9661f,50.8146f,17.2482f,8.0429f,7.5666f,7.9709f,21.8403f,20.1694f,23.3004f,50.9151f,46.239f,38.7323f,29.6946f,32.9876f, + 23.0013f,39.7318f,19.4486f,37.6147f,-0.1506f,5.3246f,3.6173f,24.2573f,4.3941f,9.7105f,24.0364f,35.3681f,17.7805f,35.7681f,16.4144f,17.4362f,8.4987f, + 26.8108f,36.2937f,31.6442f,29.7221f,8.7445f,33.3301f,4.0939f,13.078f,45.1481f,29.0172f,21.6548f,35.408f,27.1861f,2.2576f,40.6804f,36.2201f,29.7352f, + 
29.1244f,38.7444f,5.8721f,33.5983f,48.2694f,34.4161f,19.7148f,13.8085f,13.6075f,22.5042f,37.8002f,50.0543f,48.5314f,20.3694f,28.5042f,-0.4679f,4.4245f, + 18.9837f,40.7724f,2.7611f,44.0431f,37.186f,27.7361f,14.6001f,9.1721f,14.6087f,21.4072f,49.3344f,11.4668f,14.6171f,15.2502f,5.244f}); + auto range = NDArrayFactory::create('c', {1,2}, {0, 50}); +// auto exp = NDArrayFactory::create('c', {5}, {23, 19, 20, 23, 15}); // 23, 15, 24, 17, 21 + auto exp = NDArrayFactory::create('c', {5}, {23, 15, 24, 17, 21}); + + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range}, {}, {5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *out = results.at(0); + + ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("5HIST"); + ASSERT_TRUE(exp.equalsTo(out)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, histogram_fixed_width_test6) { + + auto input = NDArrayFactory::create('c', {7},{0.0, 0.1, 0.1, 0.3, 0.5, 0.5, 0.9}); + auto range = NDArrayFactory::create('c', {2}, {0, 1}); + auto bins = NDArrayFactory::create(5); + + auto exp = NDArrayFactory::create('c', {5}, {3, 1, 2, 0, 1}); + + sd::ops::histogram_fixed_width op; + auto results = op.evaluate({&input, &range, &bins}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out = results.at(0); + // out->printShapeInfo(); + // out->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(out)); + ASSERT_TRUE(exp.equalsTo(out)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, NTH_Element_Test_1) { + + NDArray input = NDArrayFactory::create('c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(4.f); + NDArray exp = NDArrayFactory::create(5.f); + + //input.linspace(1.f); + + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* output = 
results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, NTH_Element_Test_2) { + + NDArray input = NDArrayFactory::create('c', {3, 4}, {10, 11, 9, 12, 8, 7, 6, 5, 1, 3, 2, 4}); + NDArray n = NDArrayFactory::create(3); + NDArray exp = NDArrayFactory::create({12.f, 8.f, 4.f}); + +// input.linspace(1.f); + + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, NTH_Element_Test_3) { + + NDArray input = NDArrayFactory::create('c', {3,4}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(3); + NDArray exp = NDArrayFactory::create({1.f, 5.f, 2.f}); + + //input.linspace(1.f); + + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); // with reverse = true + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, NTH_Element_Test_4) { + + NDArray input = NDArrayFactory::create('c', {2, 2, 3}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = NDArrayFactory::create('c', {2,2}, {10.f, 11.f, 12.f, 4.f}); + + //input.linspace(1.f); + + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +/////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests10, NTH_Element_Test_04) { + + NDArray input = NDArrayFactory::create('c', {6, 15}); + NDArray n = NDArrayFactory::create(4); + NDArray exp = NDArrayFactory::create('c', {6}, {5.f, 20.f, 35.f, 50.f, 65.f, 80.f}); + + input.linspace(1.f); + + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, NTH_Element_Test_5) { + + NDArray input = NDArrayFactory::create('c', {2, 2, 3}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = NDArrayFactory::create('c', {2,2}, {1.f, 7.f, 5.f, 2.f}); + +// input.linspace(1.f); + + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, NTH_Element_Test_6) { + + NDArray input = NDArrayFactory::create('c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = NDArrayFactory::create(0); + NDArray exp = NDArrayFactory::create(1.f);//NDArrayFactory::create('c', {2,2}, {1.f, 4.f, 7.f, 10.f}); + +// input.linspace(1.f); + + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* output = results.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, NTH_Element_Test_06) { + + NDArray input = NDArrayFactory::create('c', {12}, {10, 1, 9, 8, 11, 7, 6, 5, 12, 3, 2, 4}); + NDArray n = 
NDArrayFactory::create(4); + NDArray exp = NDArrayFactory::create(8.f);//NDArrayFactory::create('c', {2,2}, {1.f, 4.f, 7.f, 10.f}); + +// input.linspace(1.f); + + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* output = results.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, NTH_Element_Test_7) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, + + 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, + 0.6765f, 0.1800f, 0.6750f, 0.2246f}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = NDArrayFactory::create('c', {2,3}, {0.7788f, 0.7271f, 0.7938f, 0.5555f, 0.6113f, 0.675f}); + + //input.linspace(1.f); + + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, NTH_Element_Test_8) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f, + + 0.6591f, 0.5555f, 0.1596f, 0.3087f, + 0.1548f, 0.4695f, 0.9939f, 0.6113f, + 0.6765f, 0.1800f, 0.6750f, 0.2246f}); + NDArray n = NDArrayFactory::create(2); + NDArray exp = NDArrayFactory::create('c', {2,3}, {0.7244f, 0.5056f, 0.5461f, 0.3087f, 0.4695f, 0.2246f}); + + //input.linspace(1.f); + + sd::ops::nth_element op; + auto results = op.evaluate({&input, &n}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* output = results.at(0); + 
+ ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, broadcast_to_test1) { + + auto input = NDArrayFactory::create('c', {3}); + auto shape = NDArrayFactory::create('c', {2}, {3, 3}); + auto exp = NDArrayFactory::create('c', {3,3}, {1, 2, 3,1, 2, 3, 1, 2, 3}); + + input.linspace(1.f); + + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, broadcast_to_test2) { + + auto input = NDArrayFactory::create('c', {1,3}); + auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); + auto exp = NDArrayFactory::create('c', {3,3}, {1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f}); + + input.linspace(1.f); + + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, broadcast_to_test3) { + + auto input = NDArrayFactory::create('c', {3,1}); + auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); + auto exp = NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.f,2.f, 2.f, 2.f,3.f, 3.f, 3.f}); + + input.linspace(1.f); + + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, 
broadcast_to_test4) { + + auto input = NDArrayFactory::create(10.); + auto shape = NDArrayFactory::create('c', {2}, {3.f, 3.f}); + auto exp = NDArrayFactory::create('c', {3,3}, {10.f, 10.f, 10.f,10.f, 10.f, 10.f, 10.f, 10.f, 10.f}); + + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, broadcast_to_test5) { + + auto input = NDArrayFactory::create(10.f); + auto shape = NDArrayFactory::create('c', {1}, {3.f}); + auto exp = NDArrayFactory::create('c', {3}, {10.f, 10.f, 10.f}); + + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, broadcast_to_test6) { + + auto input = NDArrayFactory::create(10.f); + auto shape = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {1}, {10.f}); + + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, broadcast_to_test7) { + + auto input = NDArrayFactory::create(10.f); + auto shape = NDArrayFactory::create(1); + auto exp = NDArrayFactory::create('c', {1}, {10.}); + + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *output = results.at(0); 
+ + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, broadcast_to_test8) { + + auto input = NDArrayFactory::create('c', {3}); + auto shape = NDArrayFactory::create('c', {3}, {1.f, 3.f, 3.f}); + auto exp = NDArrayFactory::create('c', {1,3,3}, {1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f}); + + input.linspace(1.f); + + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, broadcast_to_test9) { + + auto input = NDArrayFactory::create('c', {5,1,1}); + auto shape = NDArrayFactory::create('c', {5}, {2.f,1.f,5.f,1.f,3.f}); + auto exp = NDArrayFactory::create('c', {2,1,5,1,3}, {1.f, 1.f, 1.f,2.f, 2.f, 2.f,3.f, 3.f, 3.f,4.f, 4.f, 4.f,5.f, 5.f, 5.f, + 1.f, 1.f, 1.f,2.f, 2.f, 2.f,3.f, 3.f, 3.f,4.f, 4.f, 4.f,5.f, 5.f, 5.f}); + input.linspace(1.f); + + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, broadcast_to_test10) { + + auto input = NDArrayFactory::create('c', {5,1,3}); + auto shape = NDArrayFactory::create('c', {5}, {2.f,1.f,5.f,1.f,3.f}); + auto exp = NDArrayFactory::create('c', {2,1,5,1,3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f, 12.f,13.f, 14.f, 15.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f, 12.f,13.f, 14.f, 15.f}); + input.linspace(1.f); + + sd::ops::broadcast_to op; + auto results = op.evaluate({&input, &shape}, {}, 
{}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, + 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., + 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, + 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, + 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, + 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., + 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, + 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 
17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., + 20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, + 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., + 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., + 14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6, + 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., + 20.2,21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); + //input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + + //result.printIndexedBuffer("Resized to 10x10"); + //expected.printIndexedBuffer("Expect for 10x10"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_11) { + + NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); + + input.assign(0.8f); //linspace(1); + auto size = NDArrayFactory::create({65,65}); + auto ex = NDArrayFactory::create('c', {1,65,65,256}); + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + ASSERT_NE(*result, ex); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test_12) { + + NDArray input = NDArrayFactory::create('c', {1, 1, 1, 256}); + + 
input.assign(0.8f); //linspace(1); + auto size = NDArrayFactory::create({65,65}); + auto ex = NDArrayFactory::create('c', {1,65,65,256}); + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + ASSERT_NE(*result, ex); +} + +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_1) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1., 2., 3., 4., + 2.6, 3.6, 4.6, 5.6, + 5., 6., 7., 8., + 7.4, 8.4, 9.4, 10.4, + 9., 10., 11., 12., + + 4., 5., 6., 7., + 5.6, 6.6, 7.6, 8.6, + 8., 9., 10., 11., + 10.4, 11.4, 12.4, 13.4, + 12., 13., 14., 15., + + 10., 11., 12., 13., + 11.6, 12.6, 13.6, 14.6, + 14., 15., 16., 17., + 16.4, 17.4, 18.4, 19.4, + 18., 19., 20., 21., + + 13., 14., 15., 16., + 14.6, 15.6, 16.6, 17.6, + 17., 18., 19., 20., + 19.4, 20.4, 21.4, 22.4, + 21., 22., 23., 24. 
+ }); + //input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printIndexedBuffer("Resized to 4x5 bilinear with half pixels"); + //expected.printIndexedBuffer("Expect for 10x10"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test1_2) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1.f, 2.f, 3.f, 4.f, + 2.6f, 3.6f, 4.6f, 5.6f, + 5.f, 6.f, 7.f, 8.f, + 7.4f, 8.4f, 9.4f, 10.4f, + 9.f, 10.f, 11.f, 12.f, + + 4.f, 5.f, 6.f, 7.f, + 5.6f, 6.6f, 7.6f, 8.6f, + 8.f, 9.f, 10.f, 11.f, + 10.4f, 11.4f, 12.4f, 13.4f, + 12.f, 13.f, 14.f, 15.f, + + 10.f, 11.f, 12.f, 13.f, + 11.6f, 12.6f, 13.6f, 14.6f, + 14.f, 15.f, 16.f, 17.f, + 16.4f, 17.4f, 18.4f, 19.4f, + 18.f, 19.f, 20.f, 21.f, + + 13.f, 14.f, 15.f, 16.f, + 14.6f, 15.6f, 16.6f, 17.6f, + 17.f, 18.f, 19.f, 20.f, + 19.4f, 20.4f, 21.4f, 22.4f, + 21.f, 22.f, 23.f, 24.f + }); + //input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Resized to 4x5"); +// expected.printBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test01) { + + NDArray input = NDArrayFactory::create('c', {2,3,4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, 
{1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, + 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., + 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, + 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, + 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, + 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., + 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, + 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., + 20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, + 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., + 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., + 14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 
17.6, 18.6, 19.6, + 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., + 20.2,21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); + //input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + + //result.printIndexedBuffer("Resized to 10x10"); + //expected.printIndexedBuffer("Expect for 10x10"); +// result.printShapeInfo("Output shape"); +// expected.printShapeInfo("Expect shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +TEST_F(DeclarableOpsTests10, ResizeImages_Test1) { + + NDArray input = NDArrayFactory::create('c', {2, 4, 5, 3}); + input.linspace(1.); + + auto expected = NDArrayFactory::create('c', {2, 7, 9, 3}, { + 1.f, 2.f, 3.f, 2.6666667f, 3.6666667f, 4.666667f, 4.3333335f, 5.3333335f, 6.3333335f, 6.f, + 7.f, 8.f, 7.666667f, 8.666667f, 9.666667f, 9.333334f, 10.333334f, 11.333334f, 11.f, 12.f, + 13.f, 12.666667f, 13.666667f, 14.666667f, 13.f, 14.f, 15.f, 9.571429f, 10.571429f, 11.571429f, + 11.238095f, 12.238095f, 13.238095f, 12.904762f, 13.904762f, 14.904762f, 14.571429f, 15.571429f, 16.57143f, + 16.238096f, 17.238096f, 18.238096f, 17.904762f, 18.904762f, 19.904762f, 19.57143f, 20.57143f, 21.57143f, + 21.238096f, 22.238096f, 23.238096f, 21.57143f, 22.57143f, 23.57143f, 18.142859f, 19.142859f, 20.142859f, + 19.809525f, 20.809525f, 21.809525f, 21.476192f, 22.476192f, 23.476192f, 23.142859f, 24.142859f, 25.142859f, + 24.809526f, 25.809526f, 26.809526f, 26.476192f, 27.476192f, 28.476192f, 28.142859f, 29.142859f, 30.142859f, + 29.809526f, 30.809526f, 31.809526f, 30.142859f, 31.142859f, 32.142857f, 26.714287f, 27.714287f, 
28.714287f, + 28.380955f, 29.380955f, 30.380955f, 30.04762f, 31.04762f, 32.047623f, 31.714287f, 32.714287f, 33.714287f, + 33.380955f, 34.380955f, 35.380955f, 35.047623f, 36.047623f, 37.047623f, 36.714287f, 37.714287f, 38.714287f, + 38.380955f, 39.380955f, 40.380955f, 38.714287f, 39.714287f, 40.714287f, 35.285717f, 36.285717f, 37.285717f, + 36.952385f, 37.952385f, 38.952385f, 38.61905f, 39.61905f, 40.61905f, 40.285717f, 41.285717f, 42.285717f, + 41.952385f, 42.952385f, 43.952385f, 43.61905f, 44.61905f, 45.61905f, 45.285717f, 46.285717f, 47.285717f, + 46.952385f, 47.952385f, 48.952385f, 47.285717f, 48.285717f, 49.285717f, 43.857143f, 44.857143f, 45.857143f, + 45.52381f, 46.52381f, 47.52381f, 47.190475f, 48.190475f, 49.190475f, 48.857143f, 49.857143f, 50.857143f, + 50.52381f, 51.52381f, 52.52381f, 52.190475f, 53.190475f, 54.190475f, 53.857143f, 54.857143f, 55.857143f, + 55.52381f, 56.52381f, 57.52381f, 55.857143f, 56.857143f, 57.857143f, 46.f, 47.f, 48.f, + 47.666668f, 48.666668f, 49.666668f, 49.333332f, 50.333332f, 51.333332f, 51.f, 52.f, 53.f, + 52.666668f, 53.666668f, 54.666668f, 54.333332f, 55.333332f, 56.333332f, 56.f, 57.f, 58.f, + 57.666668f, 58.666668f, 59.666668f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, + 62.666668f, 63.666668f, 64.666664f, 64.333336f, 65.333336f, 66.333336f, 66.f, 67.f, 68.f, + 67.666664f, 68.666664f, 69.666664f, 69.333336f, 70.333336f, 71.333336f, 71.f, 72.f, 73.f, + 72.666664f, 73.666664f, 74.666664f, 73.f, 74.f, 75.f, 69.57143f, 70.57143f, 71.57143f, + 71.2381f, 72.2381f, 73.23809f, 72.90476f, 73.90476f, 74.90476f, 74.57143f, 75.57143f, 76.57143f, + 76.23809f, 77.23809f, 78.23809f, 77.90476f, 78.90476f, 79.90476f, 79.57143f, 80.57143f, 81.57143f, + 81.23809f, 82.23809f, 83.23809f, 81.57143f, 82.57143f, 83.57143f, 78.14286f, 79.14286f, 80.14286f, + 79.809525f, 80.809525f, 81.809525f, 81.4762f, 82.4762f, 83.4762f, 83.14286f, 84.14286f, 85.14286f, + 84.809525f, 85.809525f, 86.809525f, 86.4762f, 87.4762f, 88.4762f, 88.14286f, 89.14286f, 
90.14286f, + 89.809525f, 90.809525f, 91.809525f, 90.14286f, 91.14286f, 92.14286f, 86.71429f, 87.71429f, 88.71429f, + 88.38095f, 89.38095f, 90.38095f, 90.04762f, 91.04762f, 92.04762f, 91.71429f, 92.71429f, 93.71429f, + 93.38095f, 94.38095f, 95.38095f, 95.04762f, 96.04762f, 97.04762f, 96.71429f, 97.71429f, 98.71429f, + 98.38095f, 99.38095f, 100.38095f, 98.71429f, 99.71429f, 100.71429f, 95.28571f, 96.28571f, 97.28571f, + 96.95238f, 97.95238f, 98.95238f, 98.61905f, 99.61905f, 100.61905f, 100.28571f, 101.28571f, 102.28571f, + 101.95238f, 102.95238f, 103.95238f, 103.61905f, 104.61905f, 105.61905f, 105.28571f, 106.28571f, 107.28571f, + 106.95238f, 107.95238f, 108.95238f, 107.28571f, 108.28571f, 109.28571f, 103.85715f, 104.85715f, 105.85715f, + 105.5238f, 106.5238f, 107.5238f,107.190475f,108.190475f,109.190475f, 108.85715f, 109.85715f, 110.85715f, + 110.5238f, 111.5238f, 112.5238f,112.190475f,113.190475f,114.190475f, 113.85715f, 114.85715f, 115.85715f, + 115.5238f, 116.5238f, 117.5238f, 115.85715f, 116.85715f, 117.85715f, 106.f, 107.f, 108.f, + 107.666664f,108.666664f,109.666664f,109.333336f,110.333336f,111.333336f, 111.f, 112.f, 113.f, + 112.666664f,113.666664f,114.666664f,114.333336f,115.333336f,116.333336f, 116.f, 117.f, 118.f, + 117.666664f,118.666664f,119.666664f, 118.f, 119.f, 120.f + }); + + auto size = NDArrayFactory::create({7, 11}); + sd::ops::resize_images op; + auto results = op.evaluate({&input, &size}, {}, {0}, {false, true}); // resize with bilinear method + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray *result = results.at(0); + +// result->printBuffer("Resized to 7x9"); +// expected.printBuffer("Expect for 7x9"); +// result.printShapeInfo("Output shape"); +// expected.printShapeInfo("Expect shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test02) { + + NDArray input = NDArrayFactory::create('c', {2, 5,5,3}, { + 0.7788f, 0.8012f, 0.7244f, + 
0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f, + 0.9234f, 0.0856f, 0.7938f, + 0.6591f, 0.5555f, 0.1596f, + 0.3087f, 0.1548f, 0.4695f, + 0.9939f, 0.6113f, 0.6765f, + 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, + 0.2354f, 0.9752f, 0.8361f, + 0.2585f, 0.4189f, 0.7028f, + 0.7679f, 0.5373f, 0.7234f, + 0.2690f, 0.0062f, 0.0327f, + 0.0644f, 0.8428f, 0.7494f, + 0.0755f, 0.6245f, 0.3491f, + 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f, 0.3019f, + 0.3574f, 0.1704f, 0.8395f, + 0.5468f, 0.0744f, 0.9011f, + 0.6574f, 0.4124f, 0.2445f, + 0.4248f, 0.5219f, 0.6952f, + 0.4900f, 0.2158f, 0.9549f, + 0.1386f, 0.1544f, 0.5365f, + 0.0134f, 0.4163f, 0.1456f, + 0.4109f, 0.2484f, 0.3330f, + 0.2974f, 0.6636f, 0.3808f, + 0.8664f, 0.1896f, 0.7530f, + 0.7215f, 0.6612f, 0.7270f, + 0.5704f, 0.2666f, 0.7453f, + 0.0444f, 0.3024f, 0.4850f, + 0.7982f, 0.0965f, 0.7843f, + 0.5075f, 0.0844f, 0.8370f, + 0.6103f, 0.4604f, 0.6087f, + 0.8594f, 0.4599f, 0.6714f, + 0.2744f, 0.1981f, 0.4143f, + 0.7821f, 0.3505f, 0.5040f, + 0.1180f, 0.8307f, 0.1817f, + 0.8442f, 0.5074f, 0.4471f, + 0.5105f, 0.6666f, 0.2576f, + 0.2341f, 0.6801f, 0.2652f, + 0.5394f, 0.4690f, 0.6146f, + 0.1210f, 0.2576f, 0.0769f, + 0.4643f, 0.1628f, 0.2026f, + 0.3774f, 0.0506f, 0.3462f, + 0.5720f, 0.0838f, 0.4228f, + 0.0588f, 0.5362f, 0.4756f, + 0.2530f, 0.1778f, 0.0751f, + 0.8977f, 0.3648f, 0.3065f, + 0.4739f, 0.7014f, 0.4473f, + 0.5171f, 0.1744f, 0.3487f}); + + NDArray expected = NDArrayFactory::create('c', {2, 9, 9, 3}, { + 0.7788f, 0.8012f, 0.7244f, 0.4744111f, 0.7600333f, 0.42217776f, + 0.26142225f, 0.7454778f, 0.22103335f, 0.41403335f, 0.8373667f, 0.42420003f, + 0.59844446f, 0.71318877f, 0.6011445f, 0.83055556f, 0.264911f, 0.7387556f, + 0.83529997f, 0.2422334f, 0.5823999f, 0.6884666f, 0.5032889f, 0.23006654f, + 0.6591f, 0.5555f, 0.1596f, 0.5176333f, 0.44208887f , 0.5827889f, + 0.5938309f, 0.5646876f, 0.5123568f, 0.61811364f, 0.6748667f, 0.44617534f, + 0.43473703f, 0.7353667f, 0.3969963f, 0.35003704f, 0.6654419f, 
0.46649635f, + 0.41335183f, 0.39988017f, 0.7140149f, 0.43368888f, 0.45865932f, 0.72049254f, + 0.42537406f, 0.73366547f, 0.5662765f, 0.42371112f, 0.78866667f, 0.53543335f, + 0.30312222f, 0.18414445f, 0.49542224f, 0.67293704f, 0.4168852f, 0.59891605f, + 0.8822444f, 0.60281235f, 0.62855184f, 0.4495222f, 0.6014852f, 0.36275554f, + 0.15933579f, 0.5788963f, 0.34024328f, 0.08295307f, 0.52441484f, 0.6826569f, + 0.10747781f, 0.64715934f, 0.80707777f, 0.19927411f, 0.8880544f, 0.7861703f, + 0.21763334f, 0.9362333f, 0.78198886f, 0.27523333f, 0.3308667f, 0.6250333f, + 0.5907889f, 0.45925558f, 0.6709963f, 0.7761333f, 0.5249852f, 0.63986665f, + 0.4406333f, 0.34007773f, 0.3003666f, 0.19945924f, 0.33715558f, 0.24757043f, + 0.09977405f, 0.60721123f, 0.6248297f, 0.08286668f, 0.7239556f, 0.6876333f, + 0.12114445f, 0.73849255f ,0.54079986f, 0.12879999f, 0.74139994f, 0.51143324f, + 0.32978892f, 0.45314446f, 0.58711106f, 0.5576408f, 0.5464408f, 0.6107901f, + 0.68978024f, 0.55681235f, 0.5833172f, 0.43907034f, 0.23548517f, 0.35123706f, + 0.26263458f, 0.18254575f, 0.33890504f, 0.1976099f, 0.5321877f, 0.65619516f, + 0.18267044f, 0.6404851f, 0.63069254f, 0.20112106f, 0.58788633f, 0.37666163f, + 0.20481117f, 0.57736665f, 0.32585555f, 0.50801116f, 0.5387556f, 0.29788882f, + 0.59799266f, 0.7008482f, 0.35215425f, 0.6330642f, 0.753121f, 0.42497158f, + 0.44849625f, 0.36611477f, 0.5719964f, 0.36038768f, 0.1586321f, 0.70625067f, + 0.416968f, 0.22043455f, 0.82134944f, 0.4690964f, 0.31661478f, 0.6675073f, + 0.5182569f, 0.4357136f, 0.33437145f, 0.528089f, 0.4595333f, 0.26774442f, + 0.52779996f, 0.5559667f, 0.35320008f, 0.5630963f, 0.62568885f, 0.44562602f, + 0.557237f, 0.62408876f, 0.5438927f, 0.3867555f, 0.3371999f, 0.6655223f, + 0.30325183f, 0.17024446f, 0.71867025f, 0.35021478f, 0.18318895f, 0.6690962f, + 0.4377444f, 0.24482228f, 0.5241777f, 0.5523185f, 0.33891484f, 0.3156962f, + 0.5752333f, 0.3577333f, 0.27400002f, 0.44196665f, 0.52757776f, 0.6382001f, + 0.47803456f, 0.3974851f, 0.7738359f, 
0.4686691f, 0.27816284f, 0.8476581f, + 0.2775703f, 0.20192216f, 0.6742259f, 0.14285672f, 0.20554078f, 0.4944727f, + 0.0927209f, 0.32894826f, 0.30523813f, 0.19454071f, 0.3410815f, 0.26075178f, + 0.3976642f, 0.27903205f, 0.31276423f, 0.43828884f, 0.2666222f, 0.32316667f, + 0.4248f, 0.5219f, 0.6952f, 0.46102223f, 0.35184443f, 0.8394778f, + 0.45095554f, 0.20897777f, 0.9084111f, 0.2557333f, 0.17486666f, 0.6759666f, + 0.11077777f, 0.21260004f, 0.44963327f, 0.04122221f, 0.35810006f, 0.23246664f, + 0.14590007f, 0.36033332f, 0.2080667f, 0.3667334f, 0.2670555f, 0.31217784f, + 0.4109f, 0.2484f, 0.333f, 0.2974f, 0.6636f, 0.3808f, + 0.6135111f, 0.40026665f, 0.5875778f, 0.8503f, 0.24200003f, 0.7501111f, + 0.76979995f, 0.50400007f, 0.7356667f, 0.6879222f, 0.57351106f, 0.73106664f, + 0.60397774f, 0.35428885f, 0.74123335f, 0.39506656f, 0.27853334f, 0.6585333f, + 0.10284433f, 0.29842222f, 0.5139222f, 0.0444f, 0.3024f, 0.485f, + 0.5756222f, 0.34854442f, 0.6049667f, 0.6263938f, 0.22777282f, 0.71313334f, + 0.66620123f, 0.17765433f, 0.78429013f, 0.6621518f, 0.41014817f, 0.7074074f, + 0.67555183f, 0.51060987f, 0.6708259f, 0.7151259f, 0.41302344f, 0.6946963f, + 0.5446962f, 0.33081108f, 0.6180703f, 0.23426408f, 0.25884813f, 0.4744469f, + 0.17217779f, 0.24445555f, 0.44572222f, 0.7964111f, 0.12472223f, 0.7531556f, + 0.6118617f, 0.1483889f, 0.75928515f, 0.4833407f, 0.2004667f, 0.7449173f, + 0.57893336f, 0.3661889f, 0.6485592f, 0.6772543f, 0.46945432f, 0.5984506f, + 0.7796679f, 0.47903457f, 0.617716f, 0.63706285f, 0.40579626f, 0.54952586f, + 0.33111224f, 0.27734566f, 0.42303205f, 0.26992223f, 0.25165558f, 0.39773333f, + 0.7874667f, 0.26583335f, 0.5974333f, 0.4876703f, 0.44144446f, 0.48782218f, + 0.30543333f, 0.57191116f, 0.41133702f, 0.5934334f, 0.5218f, 0.46735552f, + 0.73524815f, 0.5152815f, 0.47753704f, 0.6577852f, 0.5741519f, 0.41896293f, + 0.50037766f, 0.57161117f, 0.3686555f, 0.28967398f, 0.5281297f, 0.3238592f, + 0.24753332f, 0.5194334f, 0.31489998f, 0.72816664f, 0.37683335f, 
0.5285778f, + 0.3895555f, 0.5582283f, 0.32292962f, 0.18990126f, 0.6730641f, 0.18445063f, + 0.5460741f, 0.5216629f, 0.31464812f, 0.6978098f, 0.45279747f, 0.36710492f, + 0.5428901f, 0.5077358f, 0.30295062f, 0.42367774f, 0.53567034f, 0.28493333f, + 0.32827038f, 0.54560244f, 0.2976741f, 0.30918893f, 0.5475888f, 0.30022222f, + 0.5933333f, 0.44266668f, 0.59002227f, 0.3305555f, 0.4106049f, 0.31789258f, + 0.16793211f, 0.36878017f, 0.11760493f, 0.40592593f, 0.28790364f, 0.20468517f, + 0.5172234f, 0.22784683f, 0.27239504f, 0.4384765f, 0.19901967f, 0.3110494f, + 0.43695557f, 0.19709623f, 0.34693336f, 0.4869186f, 0.21310854f, 0.38097042f, + 0.49691117f, 0.21631104f, 0.3877778f, 0.37919992f, 0.4914f, 0.56826663f, + 0.26019996f, 0.34673333f, 0.29495183f, 0.21430746f, 0.23090371f, 0.09418149f, + 0.46084452f, 0.23042224f, 0.1835889f, 0.56450003f, 0.23844449f, 0.26893705f, + 0.45383334f, 0.2592223f, 0.34819633f, 0.45761114f, 0.21635559f, 0.38596666f, + 0.5376852f, 0.13105926f, 0.39607778f, 0.55370003f, 0.11400001f, 0.3981f, + 0.11219993f, 0.5287333f, 0.49104443f, 0.18227404f, 0.3386963f, 0.26007527f, + 0.30624574f, 0.20396544f, 0.09970618f, 0.6458075f, 0.2904593f, 0.22173704f, + 0.7636852f, 0.40607417f, 0.32631359f, 0.549037f, 0.5653705f, 0.40470868f, + 0.4831852f, 0.47417036f, 0.40968886f, 0.5165309f, 0.21597281f, 0.3657259f, + 0.5232f, 0.16433334f, 0.3569333f, 0.0588f, 0.5362f, 0.4756f, + 0.16668889f, 0.33708888f, 0.25309998f, 0.32463336f, 0.19857779f, 0.10081112f, + 0.68280005f, 0.3024667f, 0.22936666f, 0.80352217f, 0.43960005f, 0.33778888f, + 0.5680777f, 0.6266f, 0.41601112f, 0.4883f, 0.52573323f, 0.4144333f, + 0.5123f, 0.23295549f, 0.35965553f, 0.5171f, 0.1744f, 0.3487f + }); + //input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {9, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Resized to 9x9"); +// expected.printBuffer("Expect for 9x9"); +// 
result.printShapeInfo("Output shape"); +// expected.printShapeInfo("Expect shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test2) { + + NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); + NDArray size = NDArrayFactory::create({10, 10}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, {1., 2., 3., 4., 2.2, 3.2, 4.2, 5.2, 3.4, 4.4, 5.4, 6.4, + 4.6, 5.6, 6.6, 7.6, 5.8, 6.8, 7.8, 8.8, 7., 8., 9., 10., + 8.2, 9.2, 10.2, 11.2, 9., 10., 11., 12., 9., 10., 11., 12., + 9., 10., 11., 12., 3.4, 4.4, 5.4, 6.4, 4.6, 5.6, 6.6, 7.6, + 5.8, 6.8, 7.8, 8.8, 7.0, 8., 9., 10., 8.2, 9.2, 10.2, 11.2, + 9.4,10.4, 11.4, 12.4,10.6, 11.6,12.6, 13.6,11.4, 12.4, 13.4, 14.4, + 11.4,12.4, 13.4, 14.4,11.4, 12.4,13.4, 14.4, 5.8, 6.8, 7.8, 8.8, + 7., 8., 9., 10., 8.2, 9.2,10.2, 11.2, 9.4, 10.4, 11.4, 12.4, + 10.6,11.6, 12.6, 13.6,11.8, 12.8,13.8, 14.8,13.0, 14.0, 15.0, 16., + 13.8,14.8, 15.8, 16.8,13.8, 14.8,15.8, 16.8,13.8, 14.8, 15.8, 16.8, + 8.2, 9.2, 10.2, 11.2, 9.4, 10.4,11.4, 12.4,10.6, 11.6, 12.6, 13.6, + 11.8,12.8, 13.8, 14.8,13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4,16.4, 17.4, 18.4,16.2, 17.2,18.2, 19.2,16.2, 17.2, 18.2, 19.2, + 16.2,17.2, 18.2, 19.2,10.6, 11.6,12.6, 13.6,11.8, 12.8, 13.8, 14.8, + 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8,19.8, 20.8,18.6, 19.6, 20.6, 21.6, + 18.6,19.6, 20.6, 21.6,18.6, 19.6,20.6, 21.6,13., 14., 15., 16., + 14.2,15.2, 16.2, 17.2,15.4, 16.4,17.4, 18.4,16.6, 17.6, 18.6, 19.6, + 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2,16.2, 17.2,15.4, 16.4, 
17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., + 20.2,21.2, 22.2, 23.2,21., 22., 23., 24., 21., 22., 23., 24., + 21., 22., 23., 24., 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2, + 15.4,16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8, + 19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2,21., 22., 23., 24., + 21., 22., 23., 24., 21., 22., 23., 24., 13., 14., 15., 16., + 14.2,15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4,16.6, 17.6, 18.6, 19.6, + 17.8,18.8, 19.8, 20.8,19., 20., 21., 22., 20.2, 21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24., + 13., 14., 15., 16., 14.2, 15.2, 16.2, 17.2,15.4, 16.4, 17.4, 18.4, + 16.6,17.6, 18.6, 19.6,17.8, 18.8, 19.8, 20.8,19., 20., 21., 22., + 20.2,21.2, 22.2, 23.2, + 21., 22., 23., 24., 21., 22., 23., 24., 21., 22., 23., 24.}); + //input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test3) { + + NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, + { 1., 2., 3., 4. 
, + 1.8888888, 2.8888888, 3.8888888, 4.888889, + 2.7777777, 3.7777777, 4.7777777, 5.7777777, + 3.6666667, 4.666667 , 5.666667, 6.666667 , + 4.5555553, 5.5555553, 6.5555553, 7.5555553, + 5.4444447, 6.4444447, 7.4444447, 8.444445, + 6.3333335, 7.3333335, 8.333334, 9.333334, + 7.2222223, 8.222222, 9.222222, 10.222222, + 8.111111, 9.111111, 10.111111, 11.111111, + 9., 10., 11., 12., + + 2.3333335, 3.3333335, 4.3333335, 5.3333335, + 3.2222223, 4.2222223, 5.2222223, 6.2222223, + 4.111111, 5.111111, 6.111111, 7.111111, + 5., 6., 7., 8., + 5.888889, 6.888889, 7.888889, 8.888888, + 6.777778, 7.777778, 8.777778, 9.777778, + 7.666667, 8.666667, 9.666667, 10.666667, + 8.555555, 9.555555, 10.555555, 11.555555, + 9.444444, 10.444444, 11.444444, 12.444444, + 10.333333, 11.333333, 12.333333, 13.333333, + + 3.6666667, 4.666667, 5.666667, 6.666667, + 4.5555553, 5.5555553, 6.5555553, 7.5555553, + 5.4444447, 6.4444447, 7.4444447, 8.444445 , + 6.3333335, 7.3333335, 8.333334, 9.333334 , + 7.2222223, 8.222222, 9.222222, 10.222222 , + 8.111112, 9.111112, 10.111112, 11.111112 , + 9., 10., 11.000001, 12.000001 , + 9.888889, 10.888889, 11.888889, 12.888889 , + 10.777778, 11.777778, 12.777778, 13.777778 , + 11.666667, 12.666667, 13.666667, 14.666667, + + 5., 6., 7., 8., + 5.888889, 6.888889, 7.888889, 8.888889, + 6.7777777, 7.7777777, 8.777779, 9.777779, + 7.666667, 8.666667, 9.666667, 10.666667, + 8.555555, 9.555555, 10.555555, 11.555555, + 9.444445, 10.444445, 11.444445, 12.444445, + 10.333334, 11.333334, 12.333334, 13.333334, + 11.222222, 12.222222, 13.222222, 14.222222, + 12.111111, 13.111111, 14.111111, 15.111111, + 13., 14., 15., 16., + + 6.3333335, 7.3333335, 8.333334, 9.333334, + 7.2222223, 8.222222, 9.222222, 10.222222, + 8.111111, 9.111111, 10.111112, 11.111112, + 9., 10., 11., 12., + 9.888889, 10.888889, 11.888889, 12.888889, + 10.777779, 11.777779, 12.777779, 13.777779, + 11.666667, 12.666667, 13.666668, 14.666668, + 12.555555, 13.555555, 14.555555, 15.555555, + 13.444445, 
14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, + 7.666667, 8.666667, 9.666667, 10.666667, + 8.555555, 9.555555, 10.555555, 11.555555, + 9.444445, 10.444445, 11.444445, 12.444445, + 10.333334, 11.333334, 12.333334, 13.333334, + 11.222222, 12.222222, 13.222222, 14.222222, + 12.111112, 13.111112, 14.111112, 15.111112, + 13., 14., 15.0, 16., + 13.888889, 14.888889, 15.888889, 16.88889, + 14.777778, 15.777778, 16.777779, 17.777779, + 15.666667, 16.666668, 17.666668, 18.666668, + + 9., 10., 11., 12., + 9.888889, 10.888889, 11.888889, 12.888889, + 10.777778, 11.777778, 12.777779, 13.777779, + 11.666667, 12.666666, 13.666666, 14.666666, + 12.555555, 13.555555, 14.555555, 15.555555, + 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, + 15.222221, 16.222221, 17.222221, 18.222221, + 16.11111, 17.11111, 18.11111, 19.11111, + 17., 18., 19., 20., + + 10.333334, 11.333334, 12.333334, 13.333334, + 11.222223, 12.222223, 13.222223, 14.222223, + 12.111112, 13.111112, 14.111112, 15.111112, + 13.000001, 14., 15., 16., + 13.888889, 14.888889, 15.888889, 16.88889, + 14.777779, 15.777779, 16.777779, 17.777779, + 15.666668, 16.666668, 17.666668, 18.666668, + 16.555555, 17.555555, 18.555555, 19.555555, + 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, + 11.666667, 12.666667, 13.666667, 14.666667, + 12.555555, 13.555555, 14.555555, 15.555555, + 13.444445, 14.444445, 15.444446, 16.444447, + 14.333334, 15.333333, 16.333332, 17.333332, + 15.222222, 16.222221, 17.222221, 18.222221, + 16.11111, 17.11111, 18.11111, 19.11111, + 17., 18., 19., 20., + 17.88889, 18.88889, 19.88889, 20.88889, + 18.777779, 19.777779, 20.777779, 21.777779, + 19.666668, 20.666668, 21.666668, 22.666668, + + 13., 14., 15., 16., + 13.888889, 14.888889, 15.888889, 16.88889, + 14.777778, 15.777778, 16.777779, 17.777779, + 15.666667, 16.666666, 17.666666, 18.666666, + 16.555555, 17.555555, 18.555555, 19.555555, + 
17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, + 19.222221, 20.222221, 21.222221, 22.222221, + 20.11111, 21.11111, 22.11111, 23.11111, + 21., 22., 23., 24.}); + //input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input}, {}, {10, 10}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeBilinear_Test4) { + + NDArray input = NDArrayFactory::create('c', {1, 2,3,4}); + NDArray size = NDArrayFactory::create({10, 10}); + NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, + { 1., 2., 3., 4. , + 1.8888888, 2.8888888, 3.8888888, 4.888889, + 2.7777777, 3.7777777, 4.7777777, 5.7777777, + 3.6666667, 4.666667 , 5.666667, 6.666667 , + 4.5555553, 5.5555553, 6.5555553, 7.5555553, + 5.4444447, 6.4444447, 7.4444447, 8.444445, + 6.3333335, 7.3333335, 8.333334, 9.333334, + 7.2222223, 8.222222, 9.222222, 10.222222, + 8.111111, 9.111111, 10.111111, 11.111111, + 9., 10., 11., 12., + + 2.3333335, 3.3333335, 4.3333335, 5.3333335, + 3.2222223, 4.2222223, 5.2222223, 6.2222223, + 4.111111, 5.111111, 6.111111, 7.111111, + 5., 6., 7., 8., + 5.888889, 6.888889, 7.888889, 8.888888, + 6.777778, 7.777778, 8.777778, 9.777778, + 7.666667, 8.666667, 9.666667, 10.666667, + 8.555555, 9.555555, 10.555555, 11.555555, + 9.444444, 10.444444, 11.444444, 12.444444, + 10.333333, 11.333333, 12.333333, 13.333333, + + 3.6666667, 4.666667, 5.666667, 6.666667, + 4.5555553, 5.5555553, 6.5555553, 7.5555553, + 5.4444447, 6.4444447, 7.4444447, 8.444445 , + 6.3333335, 7.3333335, 8.333334, 9.333334 , + 7.2222223, 8.222222, 9.222222, 10.222222 , + 8.111112, 9.111112, 10.111112, 11.111112 , + 9., 10., 11.000001, 12.000001 , + 9.888889, 10.888889, 11.888889, 12.888889 , + 
10.777778, 11.777778, 12.777778, 13.777778 , + 11.666667, 12.666667, 13.666667, 14.666667, + + 5., 6., 7., 8., + 5.888889, 6.888889, 7.888889, 8.888889, + 6.7777777, 7.7777777, 8.777779, 9.777779, + 7.666667, 8.666667, 9.666667, 10.666667, + 8.555555, 9.555555, 10.555555, 11.555555, + 9.444445, 10.444445, 11.444445, 12.444445, + 10.333334, 11.333334, 12.333334, 13.333334, + 11.222222, 12.222222, 13.222222, 14.222222, + 12.111111, 13.111111, 14.111111, 15.111111, + 13., 14., 15., 16., + + 6.3333335, 7.3333335, 8.333334, 9.333334, + 7.2222223, 8.222222, 9.222222, 10.222222, + 8.111111, 9.111111, 10.111112, 11.111112, + 9., 10., 11., 12., + 9.888889, 10.888889, 11.888889, 12.888889, + 10.777779, 11.777779, 12.777779, 13.777779, + 11.666667, 12.666667, 13.666668, 14.666668, + 12.555555, 13.555555, 14.555555, 15.555555, + 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, + 7.666667, 8.666667, 9.666667, 10.666667, + 8.555555, 9.555555, 10.555555, 11.555555, + 9.444445, 10.444445, 11.444445, 12.444445, + 10.333334, 11.333334, 12.333334, 13.333334, + 11.222222, 12.222222, 13.222222, 14.222222, + 12.111112, 13.111112, 14.111112, 15.111112, + 13., 14., 15.0, 16., + 13.888889, 14.888889, 15.888889, 16.88889, + 14.777778, 15.777778, 16.777779, 17.777779, + 15.666667, 16.666668, 17.666668, 18.666668, + + 9., 10., 11., 12., + 9.888889, 10.888889, 11.888889, 12.888889, + 10.777778, 11.777778, 12.777779, 13.777779, + 11.666667, 12.666666, 13.666666, 14.666666, + 12.555555, 13.555555, 14.555555, 15.555555, + 13.444445, 14.444445, 15.444445, 16.444445, + 14.333334, 15.333334, 16.333334, 17.333334, + 15.222221, 16.222221, 17.222221, 18.222221, + 16.11111, 17.11111, 18.11111, 19.11111, + 17., 18., 19., 20., + + 10.333334, 11.333334, 12.333334, 13.333334, + 11.222223, 12.222223, 13.222223, 14.222223, + 12.111112, 13.111112, 14.111112, 15.111112, + 13.000001, 14., 15., 16., + 13.888889, 14.888889, 15.888889, 16.88889, + 14.777779, 15.777779, 
16.777779, 17.777779, + 15.666668, 16.666668, 17.666668, 18.666668, + 16.555555, 17.555555, 18.555555, 19.555555, + 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, + 11.666667, 12.666667, 13.666667, 14.666667, + 12.555555, 13.555555, 14.555555, 15.555555, + 13.444445, 14.444445, 15.444446, 16.444447, + 14.333334, 15.333333, 16.333332, 17.333332, + 15.222222, 16.222221, 17.222221, 18.222221, + 16.11111, 17.11111, 18.11111, 19.11111, + 17., 18., 19., 20., + 17.88889, 18.88889, 19.88889, 20.88889, + 18.777779, 19.777779, 20.777779, 21.777779, + 19.666668, 20.666668, 21.666668, 22.666668, + + 13., 14., 15., 16., + 13.888889, 14.888889, 15.888889, 16.88889, + 14.777778, 15.777778, 16.777779, 17.777779, + 15.666667, 16.666666, 17.666666, 18.666666, + 16.555555, 17.555555, 18.555555, 19.555555, + 17.444445, 18.444445, 19.444445, 20.444445, + 18.333334, 19.333334, 20.333334, 21.333334, + 19.222221, 20.222221, 21.222221, 22.222221, + 20.11111, 21.11111, 22.11111, 23.11111, + 21., 22., 23., 24.}); + //input = 1.f; + input.linspace(1); + + sd::ops::resize_bilinear op; + auto results = op.evaluate({&input, &size}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); +// result.printIndexedBuffer("Resized to 10x10"); +// expected.printIndexedBuffer("Expected of 10x10"); +// result.printShapeInfo("Resized to 10x10 shape"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, LinSpace_Test1) { + + NDArray start = NDArrayFactory::create(1.); + NDArray finish = NDArrayFactory::create(12.); + NDArray num = NDArrayFactory::create(23); + NDArray expect = NDArrayFactory::create({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, + 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); + + sd::ops::lin_space op; + auto result = op.evaluate({&start, 
&finish, &num}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + + ASSERT_TRUE(expect.equalsTo(res)); + +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, LinSpace_Test2) { + + NDArray expect = NDArrayFactory::create({1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, + 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}); + + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + ASSERT_EQ( res->dataType(), sd::DataType::FLOAT32 ); + ASSERT_TRUE(expect.equalsTo(res)); + +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, LinSpace_Test3) { + + NDArray expect('c', { 23 }, {1., 1.5, 2., 2.5, 3., 3.5, 4., 4.5, 5., 5.5, 6., 6.5, 7., 7.5, 8., 8.5, 9., 9.5, 10., 10.5, 11., 11.5, 12.}, sd::DataType::DOUBLE ); + + sd::ops::lin_space op; + auto result = op.evaluate({}, {1, 12}, {23}, {}, { sd::DOUBLE }); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + + ASSERT_EQ( res->dataType(), expect.dataType()); + ASSERT_TRUE(expect.equalsTo(res)); + +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1, 2, 3, 4, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 6, 7, 8, + 9, 10, 11, 12, + + 1, 2, 3, 4, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 6, 7, 8, + 9, 10, 11, 12, + + 13, 14, 15, 16, + 13, 14, 15, 16, + 17, 18, 19, 20, + 17, 18, 19, 20, + 21, 22, 23, 24, + + 13, 14, 15, 16, + 13, 14, 15, 16, + 17, 18, 19, 20, + 17, 18, 19, 20, + 21, 22, 23, 24 + }); + //input = 
1.f; + input.linspace(1); + + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}, {false, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printIndexedBuffer("Resized to 4x5"); +// expected.printIndexedBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1, 2, 3, 4, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 6, 7, 8, + 9, 10, 11, 12, + + 1, 2, 3, 4, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 6, 7, 8, + 9, 10, 11, 12, + + 13, 14, 15, 16, + 13, 14, 15, 16, + 17, 18, 19, 20, + 17, 18, 19, 20, + 21, 22, 23, 24, + + 13, 14, 15, 16, + 13, 14, 15, 16, + 17, 18, 19, 20, + 17, 18, 19, 20, + 21, 22, 23, 24 + }); + //input = 1.f; + input.linspace(1); + + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printIndexedBuffer("Resized to 4x5"); +// expected.printIndexedBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test1_1_1) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {1, 4, 5, 4}, { + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 
12.f, + 9.f, 10.f, 11.f, 12.f, + + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 9.f, 10.f, 11.f, 12.f, + + 13.f, 14.f, 15.f, 16.f, + 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, + 21.f, 22.f, 23.f, 24.f, + + 13.f, 14.f, 15.f, 16.f, + 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, + 21.f, 22.f, 23.f, 24.f + }); + //input = 1.f; + input.linspace(1); + + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4,5}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printIndexedBuffer("Resized to 4x5"); +// expected.printBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, ImageResizeNeighbor_Test01) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 4}); + //NDArray paddings('c', {3,2}, {0,0, 0,1, 0,0}); + //NDArray expected('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + NDArray expected = NDArrayFactory::create('c', {4, 5, 4}, { 1, 2, 3, 4, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 6, 7, 8, + 9, 10, 11, 12, + + 1, 2, 3, 4, + 1, 2, 3, 4, + 5, 6, 7, 8, + 5, 6, 7, 8, + 9, 10, 11, 12, + + 13, 14, 15, 16, + 13, 14, 15, 16, + 17, 18, 19, 20, + 17, 18, 19, 20, + 21, 22, 23, 24, + + 13, 14, 15, 16, + 13, 14, 15, 16, + 17, 18, 19, 20, + 17, 18, 19, 20, + 21, 22, 23, 24 + }); + //input = 1.f; + input.linspace(1); + + sd::ops::resize_nearest_neighbor op; + auto results = op.evaluate({&input}, {}, {4, 5}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + + //result.printIndexedBuffer("Resized to 4x5"); + //expected.printIndexedBuffer("Expect for 4x5"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + 
+//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_1) { + + NDArray input = NDArrayFactory::create ('c', {3,3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); + + NDArray expected = NDArrayFactory::create(2.5206409f); + + sd::ops::reduce_logsumexp op; + auto results = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_2) { + + NDArray input = NDArrayFactory::create('c', {3,3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); + + NDArray expected = NDArrayFactory::create({1.0986123f, 1.8619947f, 1.0986123f}); + + sd::ops::reduce_logsumexp op; + auto results = op.evaluate({&input}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// result.printIndexedBuffer("REDUCE_LOGSUMEXP"); +// expected.printIndexedBuffer("LSE EXPECTED"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, ReduceLogSumExpTest_3) { + + NDArray input = NDArrayFactory::create('c', {3,3}, {0, 1, 0, 0, 1, 0, 0, 0, 0}); + + NDArray expected = NDArrayFactory::create('c', {1,3}, {1.0986123f, 1.8619947f, 1.0986123f}); + + sd::ops::reduce_logsumexp op; + auto results = op.evaluate({&input}, {1.f}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// result.printIndexedBuffer("REDUCE_LOGSUMEXP"); +// expected.printIndexedBuffer("LSE EXPECTED"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_1) { 
+ + NDArray boxes = NDArrayFactory::create('c', {3,4}); + NDArray scores = NDArrayFactory::create('c', {3}, {1, 2, 3}); + NDArray expected = NDArrayFactory::create('c', {3}, {2, 1, 0}); + boxes.linspace(1.f); + + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scores}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + //result.printIndexedBuffer("OOOOUUUUTTT"); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) { + + NDArray boxes = NDArrayFactory::create('c', {6,4}, {0, 0, 1, 1, 0, 0.1f, 1, 1.1f, 0, -0.1f, 1.f, 0.9f, + 0, 10, 1, 11, 0, 10.1f, 1.f, 11.1f, 0, 100, 1, 101}); + NDArray scales = NDArrayFactory::create('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {3}, {3,0,5}); + + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scales}, {0.5}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("NonMaxSuppression OUtput2"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_3) { + + NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, + 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {1}, {1}); + + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scales}, {0.5, 0.5}, {2}); + + ASSERT_EQ(Status::OK(), results.status()); + + NDArray* result = results.at(0); +// 
result.printBuffer("NonMaxSuppression OUtput3"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_4) { + + NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, + 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {1}, {1}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = NDArrayFactory::create(0.5); + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("NonMaxSuppression OUtput4"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_5) { + + NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, + 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {2}, {1, 2}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax()); + sd::ops::non_max_suppression op; + auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("NonMaxSuppression OUtput4"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + 
+TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_6) { + + NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, + 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {2}, {1,2}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax()); + sd::ops::non_max_suppression_v3 op; + auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("NonMaxSuppression OUtput6"); +// result.printShapeInfo("Ouput6 shape is"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_06) { + + NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.8115f, 0.4121f, 0.0771f, 0.4863f, + 0.7412f, 0.7607f, 0.1543f, 0.5479f, + 0.8223f, 0.2246f, 0.0049f, 0.6465f}); + NDArray scales = NDArrayFactory::create('c', {3}, {0.0029f, 0.8135f, 0.4873f}); //3, 0, 1, 2, 4, 5 + NDArray expected = NDArrayFactory::create('c', {2}, {1,2}); + NDArray maxSize = NDArrayFactory::create(2); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = NDArrayFactory::create(-DataTypeUtils::infOrMax()); + sd::ops::non_max_suppression_v3 op; + auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("NonMaxSuppression OUtput06"); +// result.printShapeInfo("Ouput06 shape is"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + 
+TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_7) { + + NDArray boxes = NDArrayFactory::create('c', {3, 4}, {0.7788f, 0.8012f, 0.7244f, 0.2329f, + 0.7271f, 0.1804f, 0.5056f, 0.8929f, + 0.5461f, 0.9234f, 0.0856f, 0.7938f}); + NDArray scales = NDArrayFactory::create('c', {3}, {0.7717f, 0.9281f, 0.9846f}); //3, 0, 1, 2, 4, 5 + NDArray maxSize = NDArrayFactory::create(0); + NDArray threshold = NDArrayFactory::create(0.5f); + NDArray scoreThreshold = NDArrayFactory::create(0.5f); + sd::ops::non_max_suppression_v3 op; + auto results = op.evaluate({&boxes, &scales, &maxSize, &threshold, &scoreThreshold}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("NonMaxSuppression OUtput7"); +// result.printShapeInfo("Ouput6 shape is"); + ASSERT_TRUE(result->isEmpty()); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_1) { + + NDArray boxes = NDArrayFactory::create('c', {4,4}, { + 0, 0, 1, 1, + 0, 0.1, 1, 1.1, + 0, -0.1, 1, 0.9, + 0, 10, 1, 11}); + NDArray scores = NDArrayFactory::create('c', {4}, {0.9, .75, .6, .95}); //3 + NDArray max_num = NDArrayFactory::create(3); + NDArray expected = NDArrayFactory::create('c', {1,}, {3}); + + sd::ops::non_max_suppression_overlaps op; + auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("NonMaxSuppressionOverlap1 Output"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_2) { + + NDArray boxes = NDArrayFactory::create('c', {4,4}, { + 0, 0, 1, 1, + 0, 0.1, 1, 1.1, + 0, -0.1, 1, 0.9, + 0, 10, 1, 11}); + NDArray scores = NDArrayFactory::create('c', {4}, {0.9, .95, .6, .75}); //3 + 
NDArray max_num = NDArrayFactory::create(3); + NDArray expected = NDArrayFactory::create('c', {3,}, {1,1,1}); + + sd::ops::non_max_suppression_overlaps op; + auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("NonMaxSuppressionOverlap Output"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressingOverlap_3) { + + NDArray boxes = NDArrayFactory::create('c', {4,4}, { + 0, 0, 1, 1, + 0, 0.1, 1, 1.1, + 0, -0.1, 1, 0.9, + 0, 10, 1, 11}); + NDArray scores = NDArrayFactory::create('c', {4}, {0.5, .95, -.6, .75}); //3 + NDArray max_num = NDArrayFactory::create(5); + NDArray expected = NDArrayFactory::create('c', {5,}, {1,1,1,1,1}); + + sd::ops::non_max_suppression_overlaps op; + auto results = op.evaluate({&boxes, &scores, &max_num}, {0.5, 0.}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("NonMaxSuppressionOverlap Output"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) { + int axis = 0; + NDArray images = NDArrayFactory::create('c', {1,2,2,1}, {1,2,3,4}); + NDArray boxes = NDArrayFactory::create('c', {1,4}, {0,0,1,1}); + NDArray boxI = NDArrayFactory::create('c', {1}, {axis}); + NDArray cropSize = NDArrayFactory::create({1, 1}); + + //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create('c', {1,1,1,1}, {2.5f}); + + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = 
results.at(0); +// result.printIndexedBuffer("Cropped and Resized"); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) { + int axis = 0; + NDArray images = NDArrayFactory::create('c', {1,2,2,1}, {1.f, 2.f, 3.f, 4.f}); + NDArray boxes = NDArrayFactory::create('c', {1,4}, {0.f, 0.f, 1.f, 1.f}); + NDArray boxI = NDArrayFactory::create('c', {1}, {axis}); + NDArray cropSize = NDArrayFactory::create({1, 1}); + + //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create('c', {1,1,1,1}, {4.f}); + + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_CropAndResize_3) { + + NDArray images ('c', {1,2,2,1}, {1,2,3,4}, sd::DataType::FLOAT32); + NDArray boxes('c', {1,4}, {0,0,1,1}, sd::DataType::FLOAT32); + NDArray boxI('c', {1}, std::vector{0}, sd::DataType::INT64); + NDArray cropSize = NDArrayFactory::create({3, 3}); + + //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected('c', {1,3,3,1}, {1.f, 1.5f, 2., 2.f, 2.5f, 3.f, 3.f, 3.5f, 4.f}, sd::DataType::FLOAT32); + + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_CropAndResize_4) { + + NDArray images('c', {1,2,2,1}, 
{1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray boxes('c', {1,4}, {0,0,1,1}, sd::DataType::FLOAT32); + NDArray boxI('c', {1}, std::vector({0.}), sd::DataType::INT32); + NDArray cropSize = NDArrayFactory::create({3, 3}); + + //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected('c', {1,3,3,1}, {1.f, 2.f, 2.f, 3.f, 4, 4.f, 3.f, 4.f, 4.f}, sd::DataType::FLOAT32); + + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Cropped and Resized"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_CropAndResize_5) { + + NDArray images('c', {1, 100, 100, 3}, sd::DataType::FLOAT32); + NDArray boxes('c', {1,4}, {0,0,1,1}, sd::DataType::FLOAT32); + NDArray boxI('c', {2}, {1,1}, sd::DataType::INT32); + NDArray cropSize = NDArrayFactory::create({10, 10}); + + //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected('c', {1, 10, 10,3}, sd::DataType::FLOAT32); + + sd::ops::crop_and_resize op; + auto results = op.evaluate({&images, &boxes, &boxI, &cropSize}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + //ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_1) { + NDArray images = NDArrayFactory::create('c', {2,4,5,3}); + NDArray boxes = NDArrayFactory::create('c', {2, 2, 4}, { + 0.f , 0.f , 1.f , 1.f , 0.1f, 0.2f, 0.9f, 0.8f, + 0.3f, 0.3f, 0.7f, 0.7f, 0.4f, 0.4f, 0.6f, 0.6f + }); + + NDArray colors = NDArrayFactory::create('c', {2, 3}, {201.f, 202.f, 203.f, 127.f, 128.f, 129.f}); + + //NDArray ('c', {6}, {0.9f, 
.75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create('c', {2,4,5,3}, { + 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 127.f, 128.f, 129.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 201.f, 202.f, 203.f, + 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, + + 61.f, 62.f, 63.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 70.f, 71.f, 72.f, 73.f, 74.f, 75.f, + 76.f, 77.f, 78.f, 127.f, 128.f, 129.f, 127.f, 128.f, 129.f, 85.f, 86.f, 87.f, 88.f, 89.f, 90.f, + 91.f, 92.f, 93.f, 201.f, 202.f, 203.f, 201.f, 202.f, 203.f, 100.f, 101.f, 102.f, 103.f, 104.f, 105.f, + 106.f, 107.f, 108.f, 109.f, 110.f, 111.f, 112.f, 113.f, 114.f, 115.f, 116.f, 117.f, 118.f, 119.f, 120.f + }); + images.linspace(1.); + sd::ops::draw_bounding_boxes op; + auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + result->syncToHost(); +// result.printBuffer("Bounded boxes"); +// expected.printBuffer("Bounded expec"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_2) { + NDArray images = NDArrayFactory::create('c', {1,9,9,1}); + NDArray boxes = NDArrayFactory::create('c', {1, 1, 4}, {0.2f, 0.2f, 0.7f, 0.7f}); + NDArray colors = NDArrayFactory::create('c', {1, 1}, {0.95f}); + + //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + NDArray expected = NDArrayFactory::create('c', {1,9,9,1}, { + 1.1f , 2.1f, 3.1f, 4.1f, 5.1f, 6.1f, 7.1f , 8.1f , 9.1f , + 10.1f , 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 16.1f , 17.1f , 18.1f , + 19.1f , 0.95f, 21.1f, 22.1f, 23.1f, 0.95f, 25.1f , 26.1f , 27.1f , + 
28.1f , 0.95f, 30.1f, 31.1f, 32.1f, 0.95f, 34.1f , 35.1f , 36.1f , + 37.1f , 0.95f, 39.1f, 40.1f, 41.1f, 0.95f, 43.1f , 44.1f , 45.1f , + 46.1f , 0.95f, 0.95f, 0.95f, 0.95f, 0.95f, 52.1f , 53.1f , 54.1f , + 55.1f , 56.1f, 57.1f, 58.1f, 59.1f , 60.1f, 61.1f , 62.1f , 63.1f , + 64.1f , 65.1f, 66.1f, 67.1f, 68.1f , 69.1f, 70.1f , 71.1f , 72.1f , + 73.1f , 74.1f, 75.1f, 76.1f, 77.1f , 78.1f, 79.1f , 80.1f , 81.1f }); + images.linspace(1.1); + sd::ops::draw_bounding_boxes op; + auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// result.syncToHost(); +// result.printBuffer("Bounded boxes 2"); +// expected.printBuffer("Bounded expec 2"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, Image_DrawBoundingBoxes_3) { + NDArray images = NDArrayFactory::create('c', {2,5,5,1}, {0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f, 0.9234f, 0.0856f, 0.7938f, + 0.6591f, 0.5555f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, + 0.9939f, 0.6113f, 0.6765f, 0.1800f, 0.6750f, 0.2246f, + 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, + 0.2585f, 0.4189f, 0.7028f, 0.7679f, 0.5373f, 0.7234f, + 0.2690f, 0.0062f, 0.0327f, 0.0644f, 0.8428f, 0.7494f, + 0.0755f, 0.6245f, 0.3491f, 0.5793f, 0.5730f, 0.1822f, + 0.6420f, 0.9143f}); + + NDArray boxes = NDArrayFactory::create('c', {2, 2, 4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f, + 0.6433f, 0.6041f, 0.6501f, 0.7612f, + 0.7605f, 0.3948f, 0.9493f, 0.8600f, + 0.7876f, 0.8945f, 0.4638f, 0.7157f}); + NDArray colors = NDArrayFactory::create('c', {1, 2}, {0.9441f, 0.5957f}); + + //NDArray ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); +// NDArray expected = NDArrayFactory::create('c', {2,5,5,1}, { +// 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, +// 0.1804f, 0.5056f, 0.8925f, 0.5461f, 
0.9234f, 0.0856f, 0.7938f, 0.9441f, +// 0.9441f, 0.1596f, 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, 0.6765f, +// 0.1800f, 0.6750f, 0.2246f, 0.0509f, 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, +// 0.2585f, 0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f, +// 0.8428f, 0.9441f,0.9441f,0.9441f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f }); + NDArray expected = NDArrayFactory::create('c', {2,5,5,1}, { + 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.9441f, 0.9441f, 0.1596f, + 0.3087f, 0.1548f, 0.4695f, 0.9939f, 0.6113f, + 0.6765f, 0.18f , 0.675f , 0.2246f, 0.0509f, + + 0.4601f, 0.8284f, 0.2354f, 0.9752f, 0.8361f, + 0.2585f, 0.4189f, 0.7028f, 0.7679f, 0.5373f, + 0.7234f, 0.269f , 0.0062f, 0.0327f, 0.0644f, + 0.8428f, 0.9441f, 0.9441f, 0.9441f, 0.3491f, + 0.5793f, 0.573f , 0.1822f, 0.642f , 0.9143f}); + sd::ops::draw_bounding_boxes op; + auto results = op.evaluate({&images, &boxes, &colors}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// result.printBuffer("Boxes3 output"); +// expected.printBuffer("Boxes3 expect"); + +// result.syncToHost(); +// result.printBuffer("Bounded boxes 2"); +// expected.printBuffer("Bounded expec 2"); + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_1) { + + NDArray x('c', {2,3}, {-63.80f, -63.75f, -63.70f, -63.5f, 0.0f, 0.1f}, sd::DataType::FLOAT32); + NDArray exp('c', {2,3}, {-63.75f, -63.75f, -63.75f, -63.5f, 0.f, 0.f}, sd::DataType::FLOAT32); + NDArray min('c', {}, std::vector{-63.65f}, sd::DataType::FLOAT32); + NDArray max('c', {}, std::vector{0.1f}, sd::DataType::FLOAT32); + + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, 
results.status()); + + auto result = results.at(0); +// result.printBuffer("Quantized"); +// exp.printBuffer("Expected"); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_2) { + + NDArray x = NDArrayFactory::create('c', {2,3}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); + NDArray exp = NDArrayFactory::create('c', {2,3}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. }); + NDArray min = NDArrayFactory::create(-63.65); + NDArray max = NDArrayFactory::create(0.1); + + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Quantized2"); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_3) { + + NDArray x = NDArrayFactory::create('c', {1,2,3,1}, {-63.80, -63.75, -63.4, -63.5, 0.0, 0.1}); + NDArray exp = NDArrayFactory::create('c', {1,2,3,1}, {-63.75, -63.75, -63.5 , -63.5 , 0. , 0. 
}); + NDArray min = NDArrayFactory::create('c', {1},{-63.65}); + NDArray max = NDArrayFactory::create('c', {1}, {0.1}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer("Quantized2"); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.777002f, 0.596913f, 0.72314f, 0.231040f, 0.509824f, + 0.179308f, 0.505282f, 0.86846f, 0.349958f, 0.509824f, + 0.087355f, 0.596913f, 0.65740f, 0.349958f, 0.159745f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// result.printIndexedBuffer("Quantized03"); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_1) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.780061f, 0.596635f, 0.725987f, 0.231950f, 0.508419f, + 0.180014f, 0.504643f, 0.868406f, 0.351335f, 0.508419f, + 0.087699f, 0.596635f, 0.659988f, 0.351335f, 0.160374f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, 
-0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {8}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// result.printIndexedBuffer("Quantized03_1"); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_2) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.775297f, 0.592226f, 0.725763f, 0.237561f, 0.503245f, + 0.189097f, 0.506084f, 0.868069f, 0.349355f, 0.503245f, + 0.094548f, 0.592226f, 0.654610f, 0.349355f, 0.153769f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + result->printIndexedBuffer("Quantized03_2"); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_03_3) { + NDArray x = NDArrayFactory::create('c', {3,5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.781600f, 0.593422f, 0.728248f, 0.233790f, 0.509014f, 0.186095f, 0.508648f, 0.868295f, 0.343809f, + 0.509014f, 0.093048f, 0.593422f, 0.658224f, 0.343809f, 0.165086f}); + NDArray min = NDArrayFactory::create({-0.2283f, -0.0719f, -0.0154f, 
-0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create({0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {6}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + result->printIndexedBuffer("Quantized03_3"); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_4) { +#ifdef FFAST_MATH + if (1 > 0) + return; +#endif + + NDArray x = NDArrayFactory::create('c', {2,4,5,3}); + NDArray exp = NDArrayFactory::create('c', {2,4,5,3},{ + 1.0588236f, 1.9607843f, 3.019608f, 4.0588236f, 5.098039f, 6.039216f, 7.0588236f, 8.039216f, 9.058824f, + 10.058824f, 10.980392f, 12.078432f, 13.058824f, 13.921569f, 15.09804f, 16.058825f, 17.058825f, 18.117647f, + 19.058825f, 20.f, 21.137257f, 22.058825f, 22.941177f, 23.882355f, 25.058825f, 26.078432f, 26.901962f, + 28.058825f, 29.019608f, 29.92157f, 31.058825f, 31.960785f, 32.941177f, 34.058823f, 35.09804f, 35.960785f, + 37.058823f, 38.039215f, 38.980392f, 40.058823f, 40.980392f, 42.000004f, 43.058826f, 43.92157f, 45.01961f, + 45.f, 47.058823f, 48.03922f, 45.f, 50.f, 51.058826f, 45.f, 50.f, 54.078434f, + 45.f, 50.f, 57.09804f, 45.f, 50.f, 60.11765f, 45.f, 50.f, 62.862747f, + 45.f, 50.f, 65.882355f, 45.f, 50.f, 68.90196f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, 45.f, 50.f, 70.f, + 45.f, 50.f, 70.f}); + NDArray min = NDArrayFactory::create({20.f, 20.f, 20.f}); + NDArray max = NDArrayFactory::create({65.f, 70.f, 90.f}); + x.linspace(1.); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + 
auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// result.printBuffer("Quantized per channels 4"); +// exp.printBuffer("Quantized per channest E"); +// auto diff = *result - exp; +// diff.printIndexedBuffer("Difference"); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_5) { + NDArray x = NDArrayFactory::create('c', {2, 3, 5, 4}); + NDArray exp = NDArrayFactory::create('c', {2, 3, 5, 4},{ + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -19.92157f, -18.980392f, -18.039217f, -16.941177f, + -16.f, -15.058824f, -13.960785f, -13.0196085f, + -11.92157f, -10.980392f, -10.039217f, -8.941177f, + -8.000001f, -7.0588236f, -5.960785f, -5.0196085f, + -3.9215698f, -2.9803925f, -2.039217f, -0.94117737f, + 0.f, 0.94117737f, 2.039215f, 2.9803925f, + 4.07843f, 5.0196075f, 5.960783f, 7.0588226f, + 8.f, 8.941177f, 10.039215f, 10.980392f, + 12.07843f, 13.019608f, 13.960783f, 15.058823f, + 16.f, 16.941177f, 18.039217f, 18.980392f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f, + 20.07843f, 
21.019608f, 21.960783f, 23.058823f, + 20.07843f, 21.019608f, 21.960783f, 23.058823f + }); + NDArray min = NDArrayFactory::create({-20.f, -19.f, -18.f, -17.f}); + NDArray max = NDArrayFactory::create({20.f, 21.f, 22.f, 23.f}); + x.linspace(-60.); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// result.printBuffer("Quantized per channels 5"); +// exp.printBuffer("Quantized per channest E"); +// auto diff = *result - exp; +// diff.printIndexedBuffer("Difference"); + + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_6) { + NDArray x = NDArrayFactory::create('c', {3, 5}, {0.7788f,0.8012f, 0.7244f, 0.2309f,0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); +// NDArray exp = NDArrayFactory::create('c', {3, 5},{ +// 0.7801f, 0.5966f, 0.7260f, 0.2320f, 0.5084f, +// 0.1800f, 0.5046f, 0.8684f, 0.3513f, 0.5084f, +// 0.0877f, 0.5966f, 0.6600f, 0.3513f, 0.1604f +// }); + + NDArray exp = NDArrayFactory::create('c', {3,5}, { + 0.77700233f, 0.596913f, 0.72314f, 0.23104f, 0.50982356f, + 0.17930824f, 0.50528157f, 0.86846f, 0.34995764f, 0.50982356f, + 0.08735529f, 0.596913f, 0.6574f, 0.34995764f, 0.15974471f}); + NDArray min = NDArrayFactory::create('c', {5}, {-0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + NDArray max = NDArrayFactory::create('c', {5}, {0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + // x.linspace(-60.); + sd::ops::fake_quant_with_min_max_vars_per_channel op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// result.printBuffer("Quantized per channels 5"); +// exp.printBuffer("Quantized per channest E"); +// auto diff = *result - exp; +// 
diff.printIndexedBuffer("Difference"); + + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_7) { + + NDArray x = NDArrayFactory::create('c', {100}); + NDArray exp = NDArrayFactory::create('c', {100}, { + 0.f, 0.01176471f, 0.01960784f, 0.03137255f, 0.03921569f, + 0.0509804f, 0.05882353f, 0.07058824f, 0.07843138f, 0.09019608f, + 0.09803922f, 0.10980393f, 0.12156864f, 0.12941177f, 0.14117648f, + 0.14901961f, 0.16078432f, 0.16862746f, 0.18039216f, 0.18823531f, + 0.20000002f, 0.21176472f, 0.21960786f, 0.23137257f, 0.2392157f, + 0.2509804f, 0.25882354f, 0.27058825f, 0.2784314f, 0.2901961f, + 0.3019608f, 0.30980393f, 0.32156864f, 0.32941177f, 0.34117648f, + 0.34901962f, 0.36078432f, 0.36862746f, 0.3803922f, 0.38823533f, + 0.40000004f, 0.41176474f, 0.41960788f, 0.43137258f, 0.43921572f, + 0.45098042f, 0.45882356f, 0.47058827f, 0.4784314f, 0.4901961f, + 0.49803925f, 0.50980395f, 0.52156866f, 0.5294118f, 0.5411765f, + 0.54901963f, 0.56078434f, 0.5686275f, 0.5803922f, 0.5882353f, + 0.6f, 0.6117647f, 0.61960787f, 0.6313726f, 0.6392157f, + 0.6509804f, 0.65882355f, 0.67058825f, 0.6784314f, 0.6901961f, + 0.69803923f, 0.70980394f, 0.72156864f, 0.7294118f, 0.7411765f, + 0.7490196f, 0.7607844f, 0.7686275f, 0.7803922f, 0.78823537f, + 0.8000001f, 0.8117648f, 0.8196079f, 0.8313726f, 0.83921576f, + 0.85098046f, 0.8588236f, 0.8705883f, 0.87843144f, 0.89019614f, + 0.8980393f, 0.909804f, 0.9215687f, 0.9294118f, 0.94117653f, + 0.9490197f, 0.9607844f, 0.9686275f, 0.9803922f, 0.98823535f + }); + NDArray min = NDArrayFactory::create('c', {1},{0.0f}); + NDArray max = NDArrayFactory::create('c', {1}, {1.f}); + x.linspace(0., 0.01); + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// 
result.printBuffer("Quantized7"); +// exp.printBuffer("Expected 7"); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, FakeQuantWithMinMaxVars_Test_8) { + + NDArray x = NDArrayFactory::create('c', {10}); + NDArray exp = NDArrayFactory::create('c', {10}, { + 0.f, 0.09803922f, 0.20000002f, 0.3019608f, 0.40000004f, 0.49803925f, + 0.6f, 0.69803923f, 0.8000001f, 0.8980393f + }); + NDArray min = NDArrayFactory::create('c', {1},{0.0f}); + NDArray max = NDArrayFactory::create('c', {1}, {1.f}); + x.linspace(0., 0.1); + sd::ops::fake_quant_with_min_max_vars op; + auto results = op.evaluate({&x, &min, &max}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// x.printBuffer("SourInput8"); +// result.printBuffer("Quantized8"); +// exp.printBuffer("Expected 8"); + ASSERT_TRUE(exp.isSameShapeStrict(*result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, bool_broadcast_test_1) { + + NDArray arr1('c', {2,2,1}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray arr2('c', { 2,2}, {0, 1, 0, 4}, sd::DataType::INT32); + + NDArray expd('c', {2,2,2}, {false, true, false, false, false, false, false, true}, sd::DataType::BOOL); + + NDArray result('c', {2,2,2}, sd::DataType::BOOL); + + arr1.applyTrueBroadcast(sd::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), arr2, result, true); + // result.printIndexedBuffer(); + // expd.printIndexedBuffer(); + + ASSERT_TRUE(expd.isSameShape(result)); + ASSERT_TRUE(expd.equalsTo(result)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests10, printIndexedTest_1) { + + NDArray arr('c', {2,2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8,9, 10, 11, 12, 13, 14, 15, 16}, sd::DataType::INT32); +// NDArray arr2('c', { 2,2}, 
{0, 1, 0, 4}, sd::DataType::INT32); + +// NDArray expd('c', {2,2,2}, {0,1,0,0, 0,0,0,1}, sd::DataType::BOOL); + +// NDArray result('c', {2,2,2}, sd::DataType::BOOL); + +// arr1.applyTrueBroadcast(sd::BroadcastBoolOpsTuple::custom(scalar::EqualTo, pairwise::EqualTo, broadcast::EqualTo), &arr2, &result, true, nullptr); + // result.printIndexedBuffer(); + // expd.printIndexedBuffer(); + +// ASSERT_TRUE(expd.isSameShape(result)); +// ASSERT_TRUE(expd.equalsTo(result)); + // arr.printIndexedBuffer("Test Print"); // output as [1, 2, 3, 4, 5, 6, 7, 8] +// +// we want output as +// [[[1 2] +// [3 4]] +// +// [[5 6] +// [7 8]]] +// + ResultSet lastDims = arr.allTensorsAlongDimension({3}); // last dim + size_t k = 0; // k from 0 to lastDims->size() + Nd4jLong rank = 4; // in this case + printf("["); + for (Nd4jLong i = 0; i < rank - 1; i++) { + + for (Nd4jLong l = 0; l < i; ++l) + printf("\n"); + printf("["); + for (Nd4jLong j = 0; j < arr.sizeAt(i); j++) { + // if (!i) + // printf("["); + // else + // printf(" "); + lastDims.at(k++)->printBuffer(); + //if (k == arr.sizeAt(i)) + // printf("]\n"); + } + printf("]\n"); + } + printf("]\n"); +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests11.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests11.cpp new file mode 100644 index 000000000..fca346a15 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests11.cpp @@ -0,0 +1,4030 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// Created by raver on 8/4/2018. +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include + +using namespace sd; + + +class DeclarableOpsTests11 : public testing::Test { +public: + + DeclarableOpsTests11() { + printf("\n"); + fflush(stdout); + } +}; + + +TEST_F(DeclarableOpsTests11, test_listdiff_1) { + auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto y = NDArrayFactory::create('c',{2}, {3, 1}); + + sd::ops::listdiff op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test1) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. 
, -21.42857, -23.07692, + -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911}); + NDArray dLdwExp('c', {2,3,4}, {3.21887, 4.96807, 6.10512, 6.80726, 7.15461, 7.19051, 6.93973, 6.41584, 5.62456, 4.56548, 3.2326 , 1.61444, + -0.30659, -2.55529, -5.16569, -8.18417,-11.67468,-15.72734,-20.47379,-26.11644,-32.9902 ,-41.71318,-53.64824,-73.05434}); + NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, + -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test2) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {2,1,4}, {15.99805, 16.72406, 16.27746, 14.83754,-44.97147,-59.99582,-79.28771,-107.35497}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + 
ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test3) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-12.49997,-13.04346, -13.63635, -14.28571,-14.99999,-15.78947, -16.66666, -17.64705,-18.75 ,-20. , -21.42857, -23.07692, + -24.99999,-27.27272, -29.99999, -33.33332,-37.49999,-42.85713, -49.99998, -59.99998,-74.99995,-99.99992,-149.99986,-299.99911}); + NDArray dLdwExp('c', {}, std::vector{-227.77286}); + NDArray dLdlExp('c', {2,3,4}, {1.58903, 1.22117, 0.99621, 0.82911, 0.69315, 0.57634, 0.47223, 0.37689, 0.28768, 0.20273, 0.12058, 0.04002, + -0.04002,-0.12058,-0.20273,-0.28768,-0.37689,-0.47223,-0.57634,-0.69315,-0.82911,-0.99621,-1.22117,-1.58903}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test4) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {4.8876 , -46.29156, -186.36887}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = 
op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + // dLdw->printIndexedBuffer(); + // dLdw->printShapeInfo(); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test5) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-1.04166,-1.08696, -1.13636, -1.19048,-1.25 ,-1.31579, -1.38889, -1.47059,-1.5625 ,-1.66667, -1.78571, -1.92308, + -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993}); + NDArray dLdwExp('c', {2,3,4}, {1.05912, 1.20488, 1.29964, 1.35815, 1.3871 , 1.39009, 1.36919, 1.32553, 1.25959, 1.17133, 1.06026, 0.92541, + 0.76533, 0.57794, 0.3604 , 0.10886,-0.18201,-0.51973,-0.91527,-1.38549,-1.95831,-2.68522,-3.67981,-5.29698}); + NDArray dLdlExp('c', {2,3,4}, {0.13242, 0.10176, 0.08302, 0.06909, 0.05776, 0.04803, 0.03935, 0.03141, 0.02397, 0.01689, 0.01005, 0.00334, + -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} + +/////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests11, log_loss_grad_test6) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {6.73432, 2.46939,-9.20372}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test7) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {}, std::vector{0.}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test8) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-1.5 ,-1.57895, -1.66667, -1.76471,-1.875 ,-2. , -2.14286, -2.30769, + -2.5 ,-2.72727, -3. , -3.33333,-3.75 ,-4.28571, -5. , -6. 
,-7.49999,-9.99999,-14.99999,-29.99991}); + NDArray dLdwExp('c', {2,3,4}, {1.56625, 1.74117, 1.85487, 1.92509, 1.95982, 1.96341, 1.93833, 1.88594, 1.80682, 1.70091, 1.56762, 1.4058 , + 1.2137 , 0.98883, 0.72779, 0.42594, 0.07689,-0.32837,-0.80302,-1.36728,-2.05466,-2.92696,-4.12046,-6.06107}); + NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0.06931, 0.05763, 0.04722, 0.03769, 0.02877, 0.02027, 0.01206, 0.004, + -0.004 ,-0.01206,-0.02027,-0.02877,-0.03769,-0.04722,-0.05763,-0.06931,-0.08291,-0.09962,-0.12212,-0.1589}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test9) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.52083,-0.54348,-0.56818, -0.59524,-0.625 ,-0.65789,-0.69444, -0.73529,-0.78125,-0.83333,-0.89286, -0.96154, + -1.04167,-1.13636,-1.25 , -1.38889,-1.5625 ,-1.78571,-2.08333, -2.5 ,-3.125 ,-4.16666,-6.24999,-12.49996}); + NDArray dLdwExp('c', {2,3,4}, {0.13412, 0.207 , 0.25438, 0.28364, 0.29811, 0.2996 , 0.28916, 0.26733, 0.23436, 0.19023, 0.13469, 0.06727, + -0.01277,-0.10647,-0.21524,-0.34101,-0.48645,-0.65531,-0.85307,-1.08819,-1.37459,-1.73805,-2.23534,-3.04393}); + 
NDArray dLdlExp('c', {2,3,4}, {0.06621, 0.05088, 0.04151, 0.03455, 0.02888, 0.02401, 0.01968, 0.0157 , 0.01199, 0.00845, 0.00502, 0.00167, + -0.00167,-0.00502,-0.00845,-0.01199,-0.0157 ,-0.01968,-0.02401,-0.02888,-0.03455,-0.04151,-0.05088,-0.06621}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test10) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,1}, std::vector{-9.49054}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test11) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {0.20365,-1.92882,-7.76537}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + 
sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test12) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.75 ,-0.789473,-0.833333, -0.882353,-0.9375 ,-1. ,-1.071428, -1.153846, + -1.25 ,-1.363636,-1.5 , -1.666666,-1.875 ,-2.142857,-2.499999, -2.999999,-3.749997,-4.999997,-7.499993,-14.999956}); + NDArray dLdwExp('c', {2,3,4}, {0.16094, 0.2484 , 0.30526, 0.34036, 0.35773, 0.35953, 0.34699, 0.32079, 0.28123, 0.22827, 0.16163, 0.08072, + -0.01533,-0.12776,-0.25828,-0.40921,-0.58373,-0.78637,-1.02369,-1.30582,-1.64951,-2.08566,-2.68241,-3.65272}); + NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. 
, 0.03466, 0.02882, 0.02361, 0.01884, 0.01438, 0.01014, 0.00603, 0.002 , + -0.002 ,-0.00603,-0.01014,-0.01438,-0.01884,-0.02361,-0.02882,-0.03466,-0.04146,-0.04981,-0.06106,-0.07945}); + + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + weights.r(3) = 0.; + + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, log_loss_grad_test13) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , + -2.08333,-2.27273, -2.5 , -2.77778,-3.125 ,-3.57143, -4.16667, -5. ,-6.25 ,-8.33333,-12.49999,-24.99993}); + NDArray dLdwExp('c', {2,3,1}, {1.75828, 2.30839, 1.25309, -1.35098, -6.16602,-16.78383}); + NDArray dLdlExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. 
, + -0.00334,-0.01005,-0.01689,-0.02397,-0.03141,-0.03935,-0.04803,-0.05776,-0.06909,-0.08302,-0.10176,-0.13242}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + + sd::ops::log_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test1) { + + NDArray input = NDArrayFactory::create('c', {1, 7, 7, 1}, { + 1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, + 8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f, + 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, + 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, + 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f + }); + NDArray expected = NDArrayFactory::create('c', {1, 30, 30, 1}, { + 1.f, 1.1976162f, 1.4174359f, 1.6775769f, 1.9961575f, 2.3283265f, + 2.550918f, 2.7360606f, 2.9655411f, 3.2929654f, 3.5441515f, 3.7380352f, + 3.948995f, 4.248106f, 4.5073795f, 4.6843743f, 4.8572845f, 5.104302f, + 5.3869915f, 5.581401f, 5.7539616f, 5.974285f, 6.272836f, 6.5204263f, + 6.718899f, 6.8871036f, 7.039068f, 7.099216f, 7.0784245f, 7.0281887f, + 2.247592f, 2.446947f, 2.6694887f, 2.9312382f, 3.248216f, 3.5745337f, + 3.78931f, 3.9656973f, 4.186417f, 4.5046535f, 4.740569f, 4.9217057f, + 5.133866f, 5.459533f, 5.7744613f, 6.0197873f, 6.254011f, 6.535633f, + 6.8097296f, 6.9607787f, 7.0749416f, 7.241601f, 7.5094895f, 7.7499495f, + 7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, + 3.6286845f, 3.830573f, 
4.0569587f, 4.3211575f, 4.6364856f, 4.9556503f, + 5.160583f, 5.3258467f, 5.535462f, 5.84216f, 6.058749f, 6.223753f, + 6.437597f, 6.797369f, 7.1836042f, 7.5164022f, 7.8290343f, 8.154773f, + 8.417635f, 8.512958f, 8.5521f, 8.649708f, 8.87788f, 9.108794f, + 9.320926f, 9.509781f, 9.667375f, 9.72694f, 9.706349f, 9.656599f, + 5.276778f, 5.480438f, 5.709702f, 5.9754477f, 6.288551f, 6.6005697f, + 6.796207f, 6.9511423f, 7.1503997f, 7.4461427f, 7.644651f, 7.794562f, + 8.009684f, 8.400473f, 8.851847f, 9.26469f, 9.649218f, 10.015648f, + 10.268647f, 10.313368f, 10.2843275f, 10.319379f, 10.512033f, 10.734956f, + 10.954604f, 11.154507f, 11.315369f, 11.374779f, 11.354242f, 11.304622f, + 7.325373f, 7.5284843f, 7.757575f, 8.022221f, 8.331997f, 8.638187f, + 8.827649f, 8.976217f, 9.168955f, 9.45726f, 9.6442375f, 9.784517f, + 9.999621f, 10.407702f, 10.896234f, 11.355122f, 11.781423f, 12.172186f, + 12.420712f, 12.4374485f, 12.370511f, 12.371386f, 12.545973f, 12.766424f, + 12.992249f, 13.20012f, 13.364252f, 13.424109f, 13.40342f, 13.353425f, + 9.493208f, 9.692467f, 9.9169445f, 10.176801f, 10.482199f, 10.78547f, + 10.974367f, 11.123442f, 11.31637f, 11.603645f, 11.790616f, 11.930889f, + 12.144082f, 12.546447f, 13.024898f, 13.4723f, 13.889232f, 14.276275f, + 14.528972f, 14.555555f, 14.50145f, 14.515459f, 14.700572f, 14.927055f, + 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.5728855f, 15.521847f, + 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, + 12.43254f, 12.588294f, 12.787534f, 13.079956f, 13.27752f, 13.426631f, + 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f, + 15.81343f, 15.881828f, 15.883522f, 15.950411f, 16.16933f, 16.40794f, + 16.636436f, 16.842583f, 17.010887f, 17.07363f, 17.05194f, 16.999537f, + 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f, + 13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, + 14.921564f, 15.264454f, 15.622843f, 15.924977f, 16.213829f, 16.532364f, + 16.8099f, 
16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f, + 17.892765f, 18.09207f, 18.261044f, 18.325531f, 18.303238f, 18.249378f, + 13.7663965f, 13.947391f, 14.148263f, 14.386917f, 14.681246f, 14.990087f, + 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f, + 16.50487f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, + 18.118288f, 18.296928f, 18.4461f, 18.651634f, 18.956806f, 19.22382f, + 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, + 15.9419365f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.14954f, + 17.361883f, 17.542162f, 17.764957f, 18.078188f, 18.315733f, 18.498205f, + 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f, + 20.13878f, 20.35177f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, + 21.626736f, 21.8155f, 21.98561f, 22.052843f, 22.029604f, 21.973448f, + 17.53522f, 17.71077f, 17.904636f, 18.13695f, 18.42784f, 18.738056f, + 18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, + 20.296894f, 20.580765f, 20.819603f, 20.976887f, 21.137802f, 21.387535f, + 21.689209f, 21.911621f, 22.119276f, 22.37999f, 22.71991f, 22.998823f, + 23.22097f, 23.40876f, 23.57911f, 23.646685f, 23.623325f, 23.566887f, + 18.746353f, 18.922657f, 19.117487f, 19.350685f, 19.64207f, 19.952137f, + 20.164913f, 20.345781f, 20.569134f, 20.88284f, 21.12133f, 21.30459f, + 21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.37289f, 22.626648f, + 22.926834f, 23.143423f, 23.343302f, 23.596668f, 23.931936f, 24.209232f, + 24.431519f, 24.619913f, 24.79011f, 24.857473f, 24.83419f, 24.777927f, + 20.16656f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, + 21.589132f, 21.768297f, 21.99003f, 22.302366f, 22.538124f, 22.719105f, + 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.83589f, 24.096842f, + 24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.62813f, + 25.850672f, 26.04014f, 26.210072f, 26.277063f, 26.253906f, 26.197956f, + 22.363024f, 22.54125f, 22.738552f, 
22.973991f, 23.266647f, 23.57634f, + 23.787327f, 23.96576f, 24.186796f, 24.498543f, 24.733124f, 24.913122f, + 25.114826f, 25.411213f, 25.675262f, 25.863028f, 26.050789f, 26.314838f, + 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f, + 28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, + 24.429443f, 24.60767f, 24.80497f, 25.04041f, 25.333065f, 25.642756f, + 25.853743f, 26.032173f, 26.25321f, 26.564959f, 26.79954f, 26.97954f, + 27.181242f, 27.47763f, 27.74168f, 27.929441f, 28.117207f, 28.381254f, + 28.677637f, 28.879343f, 29.059345f, 29.293922f, 29.617298f, 29.890451f, + 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, + 26.f, 26.178228f, 26.375526f, 26.61097f, 26.903624f, 27.213314f, + 27.424305f, 27.602734f, 27.823772f, 28.135519f, 28.3701f, 28.550098f, + 28.7518f, 29.04819f, 29.312237f, 29.5f, 29.687763f, 29.951813f, + 30.2482f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, + 31.683659f, 31.873592f, 32.043407f, 32.11024f, 32.087135f, 32.03132f, + 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f, + 28.994865f, 29.173294f, 29.39433f, 29.70608f, 29.940659f, 30.120655f, + 30.32236f, 30.618746f, 30.882797f, 31.070557f, 31.25832f, 31.522371f, + 31.818754f, 32.02046f, 32.20046f, 32.43504f, 32.758415f, 33.031567f, + 33.25422f, 33.44415f, 33.613964f, 33.680794f, 33.657696f, 33.60188f, + 29.636976f, 29.815207f, 30.0125f, 30.247944f, 30.5406f, 30.85029f, + 31.061283f, 31.239712f, 31.46075f, 31.7725f, 32.00708f, 32.187077f, + 32.38878f, 32.685165f, 32.949215f, 33.13698f, 33.32474f, 33.58879f, + 33.885178f, 34.086884f, 34.26688f, 34.501457f, 34.824837f, 35.09799f, + 35.320637f, 35.510574f, 35.68039f, 35.747215f, 35.724117f, 35.6683f, + 31.83344f, 32.011665f, 32.20897f, 32.444412f, 32.73707f, 33.046757f, + 33.257744f, 33.436176f, 33.657207f, 33.96896f, 34.203537f, 34.383537f, + 34.58524f, 34.88163f, 35.145676f, 35.33344f, 35.521206f, 35.785255f, + 36.081642f, 36.28334f, 
36.46334f, 36.69792f, 37.021297f, 37.294453f, + 37.517097f, 37.707027f, 37.876846f, 37.94368f, 37.920578f, 37.864758f, + 33.253647f, 33.431873f, 33.62917f, 33.864613f, 34.15727f, 34.466957f, + 34.677948f, 34.856377f, 35.077415f, 35.38916f, 35.623745f, 35.803745f, + 36.005447f, 36.301834f, 36.565884f, 36.753647f, 36.941406f, 37.205456f, + 37.50184f, 37.703545f, 37.883545f, 38.118122f, 38.4415f, 38.714653f, + 38.9373f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.28496f, + 34.464783f, 34.64301f, 34.840305f, 35.075752f, 35.368404f, 35.6781f, + 35.889088f, 36.067516f, 36.28855f, 36.6003f, 36.834885f, 37.014877f, + 37.216583f, 37.51297f, 37.77702f, 37.964783f, 38.152546f, 38.416595f, + 38.71298f, 38.914684f, 39.094685f, 39.32926f, 39.652645f, 39.925793f, + 40.14844f, 40.338375f, 40.508194f, 40.575024f, 40.55192f, 40.496105f, + 36.058067f, 36.23629f, 36.43359f, 36.669033f, 36.961685f, 37.271378f, + 37.48237f, 37.6608f, 37.881836f, 38.19359f, 38.42817f, 38.608162f, + 38.809868f, 39.10625f, 39.3703f, 39.558064f, 39.74583f, 40.00988f, + 40.306267f, 40.50797f, 40.68797f, 40.92255f, 41.245926f, 41.519077f, + 41.741722f, 41.931652f, 42.101475f, 42.168304f, 42.145203f, 42.089386f, + 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.52832f, + 39.739307f, 39.917736f, 40.138775f, 40.45052f, 40.685104f, 40.865097f, + 41.066803f, 41.36319f, 41.627243f, 41.815002f, 42.002766f, 42.26682f, + 42.5632f, 42.764908f, 42.944904f, 43.179485f, 43.50286f, 43.776016f, + 43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.34633f, + 40.22708f, 40.40531f, 40.602608f, 40.83805f, 41.130707f, 41.440395f, + 41.651382f, 41.82982f, 42.050854f, 42.3626f, 42.597183f, 42.77718f, + 42.97888f, 43.27527f, 43.53932f, 43.72708f, 43.914845f, 44.178894f, + 44.47528f, 44.676983f, 44.856983f, 45.09156f, 45.41494f, 45.68809f, + 45.91074f, 46.100674f, 46.270493f, 46.337322f, 46.31422f, 46.2584f, + 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.68924f, 42.998936f, + 43.209923f, 
43.388355f, 43.609394f, 43.921143f, 44.15572f, 44.335716f, + 44.53742f, 44.833805f, 45.09786f, 45.285614f, 45.473377f, 45.737427f, + 46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.24663f, + 47.469276f, 47.65921f, 47.82903f, 47.895855f, 47.872753f, 47.81694f, + 43.11514f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f, + 44.539444f, 44.717873f, 44.93891f, 45.25066f, 45.48524f, 45.665237f, + 45.86694f, 46.163326f, 46.427376f, 46.615143f, 46.802902f, 47.066956f, + 47.363342f, 47.56505f, 47.74505f, 47.979626f, 48.302998f, 48.576153f, + 48.798798f, 48.98873f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, + 44.303867f, 44.482094f, 44.679394f, 44.914833f, 45.207493f, 45.51718f, + 45.72817f, 45.9066f, 46.12764f, 46.439384f, 46.673965f, 46.853966f, + 47.055668f, 47.352055f, 47.6161f, 47.803867f, 47.99163f, 48.25568f, + 48.552063f, 48.75377f, 48.933773f, 49.16835f, 49.491726f, 49.764877f, + 49.987526f, 50.17746f, 50.347275f, 50.4141f, 50.391006f, 50.335186f, + 44.771675f, 44.949905f, 45.1472f, 45.382645f, 45.6753f, 45.98499f, + 46.195976f, 46.374413f, 46.595448f, 46.907196f, 47.141773f, 47.321774f, + 47.523476f, 47.819862f, 48.08391f, 48.27168f, 48.459446f, 48.72349f, + 49.019882f, 49.22158f, 49.401585f, 49.63616f, 49.959538f, 50.232693f, + 50.455338f, 50.64527f, 50.81509f, 50.88192f, 50.858818f, 50.803f, + 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.51359f, 45.82328f, + 46.03427f, 46.2127f, 46.433743f, 46.74549f, 46.98007f, 47.160065f, + 47.36177f, 47.658157f, 47.922207f, 48.10997f, 48.297733f, 48.561783f, + 48.858166f, 49.059875f, 49.239872f, 49.47445f, 49.79783f, 50.07098f, + 50.293625f, 50.48356f, 50.653378f, 50.720203f, 50.6971f, 50.64128f, + 44.219246f, 44.397472f, 44.594772f, 44.83021f, 45.122868f, 45.43256f, + 45.643543f, 45.82198f, 46.04302f, 46.354763f, 46.589344f, 46.76934f, + 46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.17105f, + 48.467438f, 48.66914f, 48.849144f, 49.08372f, 49.4071f, 49.680256f, + 
49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.30638f, 50.25057f}); + + auto size = NDArrayFactory::create({30, 30}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + NDArray* result = results.at(0); + +// result.printBuffer("Resized to 30x30"); +// expected.printBuffer("Expect for 30x30"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test2) { + + NDArray input = NDArrayFactory::create('c', {2, 5, 4, 3}); + NDArray expected = NDArrayFactory::create('c', {2, 10, 8, 3}, { + 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, + 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, + 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, + 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, + 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, + 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, + 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, + 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, + 19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f, + 23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f, + 28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f, + 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, + 31.000000f, 32.000000f, 33.000000f, 
32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, + 34.281250f, 35.281250f, 36.281250f, 31.000000f, 32.000000f, 33.000000f, 32.218750f, 33.218750f, 34.218750f, + 34.000000f, 35.000000f, 36.000000f, 35.500000f, 36.500000f, 37.500000f, 37.000000f, 38.000000f, 39.000000f, + 38.781250f, 39.781250f, 40.781250f, 40.000000f, 41.000000f, 42.000000f, 40.281250f, 41.281250f, 42.281250f, + 37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f, + 41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f, + 46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 44.125000f, 45.125000f, 46.125000f, + 45.343750f, 46.343750f, 47.343750f, 47.125000f, 48.125000f, 49.125000f, 48.625000f, 49.625000f, 50.625000f, + 50.125000f, 51.125000f, 52.125000f, 51.906250f, 52.906250f, 53.906250f, 53.125000f, 54.125000f, 55.125000f, + 53.406250f, 54.406250f, 55.406250f, 49.000000f, 50.000000f, 51.000000f, 50.218750f, 51.218750f, 52.218750f, + 52.000000f, 53.000000f, 54.000000f, 53.500000f, 54.500000f, 55.500000f, 55.000000f, 56.000000f, 57.000000f, + 56.781250f, 57.781250f, 58.781250f, 58.000000f, 59.000000f, 60.000000f, 58.281250f, 59.281250f, 60.281250f, + 50.125000f, 51.125000f, 52.125000f, 51.343750f, 52.343750f, 53.343750f, 53.125000f, 54.125000f, 55.125000f, + 54.625000f, 55.625000f, 56.625000f, 56.125000f, 57.125000f, 58.125000f, 57.906250f, 58.906250f, 59.906250f, + 59.125000f, 60.125000f, 61.125000f, 59.406250f, 60.406250f, 61.406250f, 61.000000f, 62.000000f, 63.000000f, + 62.218750f, 63.218750f, 64.218750f, 64.000000f, 65.000000f, 66.000000f, 65.500000f, 66.500000f, 67.500000f, + 67.000000f, 68.000000f, 69.000000f, 68.781250f, 69.781250f, 70.781250f, 70.000000f, 71.000000f, 72.000000f, + 70.281250f, 71.281250f, 72.281250f, 65.875000f, 66.875000f, 67.875000f, 67.093750f, 68.093750f, 69.093750f, + 68.875000f, 69.875000f, 70.875000f, 70.375000f, 
71.375000f, 72.375000f, 71.875000f, 72.875000f, 73.875000f, + 73.656250f, 74.656250f, 75.656250f, 74.875000f, 75.875000f, 76.875000f, 75.156250f, 76.156250f, 77.156250f, + 73.000000f, 74.000000f, 75.000000f, 74.218750f, 75.218750f, 76.218750f, 76.000000f, 77.000000f, 78.000000f, + 77.500000f, 78.500000f, 79.500000f, 79.000000f, 80.000000f, 81.000000f, 80.781250f, 81.781250f, 82.781250f, + 82.000000f, 83.000000f, 84.000000f, 82.281250f, 83.281250f, 84.281250f, 79.000000f, 80.000000f, 81.000000f, + 80.218750f, 81.218750f, 82.218750f, 82.000000f, 83.000000f, 84.000000f, 83.500000f, 84.500000f, 85.500000f, + 85.000000f, 86.000000f, 87.000000f, 86.781250f, 87.781250f, 88.781250f, 88.000000f, 89.000000f, 90.000000f, + 88.281250f, 89.281250f, 90.281250f, 85.000000f, 86.000000f, 87.000000f, 86.218750f, 87.218750f, 88.218750f, + 88.000000f, 89.000000f, 90.000000f, 89.500000f, 90.500000f, 91.500000f, 91.000000f, 92.000000f, 93.000000f, + 92.781250f, 93.781250f, 94.781250f, 94.000000f, 95.000000f, 96.000000f, 94.281250f, 95.281250f, 96.281250f, + 91.000000f, 92.000000f, 93.000000f, 92.218750f, 93.218750f, 94.218750f, 94.000000f, 95.000000f, 96.000000f, + 95.500000f, 96.500000f, 97.500000f, 97.000000f, 98.000000f, 99.000000f, 98.781250f, 99.781250f, 100.781250f, + 100.000000f, 101.000000f, 102.000000f, 100.281250f, 101.281250f, 102.281250f, 97.000000f, 98.000000f, + 99.000000f, 98.218750f, 99.218750f, 100.218750f, 100.000000f, 101.000000f, 102.000000f, 101.500000f, + 102.500000f, 103.500000f, 103.000000f, 104.000000f, 105.000000f, 104.781250f, 105.781250f, 106.781250f, + 106.000000f, 107.000000f, 108.000000f, 106.281250f, 107.281250f, 108.281250f, 104.125000f, 105.125000f, + 106.125000f, 105.343750f, 106.343750f, 107.343750f, 107.125000f, 108.125000f, 109.125000f, 108.625000f, + 109.625000f, 110.625000f, 110.125000f, 111.125000f, 112.125000f, 111.906250f, 112.906250f, 113.906250f, + 113.125000f, 114.125000f, 115.125000f, 113.406250f, 114.406250f, 115.406250f, 109.000000f, 
110.000000f, + 111.000000f, 110.218750f, 111.218750f, 112.218750f, 112.000000f, 113.000000f, 114.000000f, 113.500000f, + 114.500000f, 115.500000f, 115.000000f, 116.000000f, 117.000000f, 116.781250f, 117.781250f, 118.781250f, + 118.000000f, 119.000000f, 120.000000f, 118.281250f, 119.281250f, 120.281250f, 110.125000f, 111.125000f, + 112.125000f, 111.343750f, 112.343750f, 113.343750f, 113.125000f, 114.125000f, 115.125000f, 114.625000f, + 115.625000f, 116.625000f, 116.125000f, 117.125000f, 118.125000f, 117.906250f, 118.906250f, 119.906250f, + 119.125000f, 120.125000f, 121.125000f, 119.406250f, 120.406250f, 121.406250f + }); //input = 1.f; + input.linspace(1); + auto size = NDArrayFactory::create({10, 8}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Resized to 10x8"); +// expected.printBuffer("Expect for 10x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test3) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 4}, { + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 2.625000f, 3.625000f, 4.625000f, 5.625000f, 5.000000f, + 6.000000f, 7.000000f, 8.000000f, 7.375000f, 8.375000f, 9.375000f, 10.375000f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.375000f, 10.375000f, 11.375000f, 12.375000f, 5.875000f, 6.875000f, 7.875000f, + 8.875000f, 7.500000f, 8.500000f, 9.500000f, 10.500000f, 9.875000f, 10.875000f, 11.875000f, 12.875000f, + 12.250000f, 13.250000f, 14.250000f, 15.250000f, 13.875000f, 14.875000f, 15.875000f, 16.875000f, 14.250000f, + 15.250000f, 16.250000f, 17.250000f, 13.000000f, 14.000000f, 15.000000f, 16.000000f, 14.625000f, 15.625000f, + 16.625000f, 17.625000f, 17.000000f, 18.000000f, 19.000000f, 20.000000f, 19.375000f, 20.375000f, 21.375000f, + 
22.375000f, 21.000000f, 22.000000f, 23.000000f, 24.000000f, 21.375000f, 22.375000f, 23.375000f, 24.375000f, + 20.125000f, 21.125000f, 22.125000f, 23.125000f, 21.750000f, 22.750000f, 23.750000f, 24.750000f, 24.125000f, + 25.125000f, 26.125000f, 27.125000f, 26.500000f, 27.500000f, 28.500000f, 29.500000f, 28.125000f, 29.125000f, + 30.125000f, 31.125000f, 28.500000f, 29.500000f, 30.500000f, 31.500000f, 25.000000f, 26.000000f, 27.000000f, + 28.000000f, 26.625000f, 27.625000f, 28.625000f, 29.625000f, 29.000000f, 30.000000f, 31.000000f, 32.000000f, + 31.375000f, 32.375000f, 33.375000f, 34.375000f, 33.000000f, 34.000000f, 35.000000f, 36.000000f, 33.375000f, + 34.375000f, 35.375000f, 36.375000f, 26.125000f, 27.125000f, 28.125000f, 29.125000f, 27.750000f, 28.750000f, + 29.750000f, 30.750000f, 30.125000f, 31.125000f, 32.125000f, 33.125000f, 32.500000f, 33.500000f, 34.500000f, + 35.500000f, 34.125000f, 35.125000f, 36.125000f, 37.125000f, 34.500000f, 35.500000f, 36.500000f, 37.500000f + }); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Resized to 6x6"); +// expected.printBuffer("Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test4) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 4, 3}); + NDArray expected = NDArrayFactory::create('c', {1, 6, 8, 3}, { + 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 4.218750f, 4.000000f, 5.000000f, 6.000000f, + 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, + 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, + 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 
12.375000f, + 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, + 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, + 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, + 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, + 20.125000f, 21.125000f, 22.125000f, 21.343750f, 22.343750f, 23.343750f, 23.125000f, 24.125000f, 25.125000f, + 24.625000f, 25.625000f, 26.625000f, 26.125000f, 27.125000f, 28.125000f, 27.906250f, 28.906250f, 29.906250f, + 29.125000f, 30.125000f, 31.125000f, 29.406250f, 30.406250f, 31.406250f, 25.000000f, 26.000000f, 27.000000f, + 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, + 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, + 34.281250f, 35.281250f, 36.281250f, 26.125000f, 27.125000f, 28.125000f, 27.343750f, 28.343750f, 29.343750f, + 29.125000f, 30.125000f, 31.125000f, 30.625000f, 31.625000f, 32.625000f, 32.125000f, 33.125000f, 34.125000f, + 33.906250f, 34.906250f, 35.906250f, 35.125000f, 36.125000f, 37.125000f, 35.406250f, 36.406250f, 37.406250f + }); + input.linspace(1); + auto size = NDArrayFactory::create({6, 8}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Resized to 6x8"); +// expected.printBuffer("Expect for 6x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test5) { + + NDArray input = NDArrayFactory::create('c', {1, 4, 4, 3}); + NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 3}, { + 1.000000f, 2.000000f, 3.000000f, 2.218750f, 3.218750f, 
4.218750f, 4.000000f, 5.000000f, 6.000000f, + 5.500000f, 6.500000f, 7.500000f, 7.000000f, 8.000000f, 9.000000f, 8.781250f, 9.781250f, 10.781250f, + 10.000000f, 11.000000f, 12.000000f, 10.281250f, 11.281250f, 12.281250f, 5.875000f, 6.875000f, 7.875000f, + 7.093750f, 8.093750f, 9.093750f, 8.875000f, 9.875000f, 10.875000f, 10.375000f, 11.375000f, 12.375000f, + 11.875000f, 12.875000f, 13.875000f, 13.656250f, 14.656250f, 15.656250f, 14.875000f, 15.875000f, 16.875000f, + 15.156250f, 16.156250f, 17.156250f, 13.000000f, 14.000000f, 15.000000f, 14.218750f, 15.218750f, 16.218750f, + 16.000000f, 17.000000f, 18.000000f, 17.500000f, 18.500000f, 19.500000f, 19.000000f, 20.000000f, 21.000000f, + 20.781250f, 21.781250f, 22.781250f, 22.000000f, 23.000000f, 24.000000f, 22.281250f, 23.281250f, 24.281250f, + 19.000000f, 20.000000f, 21.000000f, 20.218750f, 21.218750f, 22.218750f, 22.000000f, 23.000000f, 24.000000f, + 23.500000f, 24.500000f, 25.500000f, 25.000000f, 26.000000f, 27.000000f, 26.781250f, 27.781250f, 28.781250f, + 28.000000f, 29.000000f, 30.000000f, 28.281250f, 29.281250f, 30.281250f, 25.000000f, 26.000000f, 27.000000f, + 26.218750f, 27.218750f, 28.218750f, 28.000000f, 29.000000f, 30.000000f, 29.500000f, 30.500000f, 31.500000f, + 31.000000f, 32.000000f, 33.000000f, 32.781250f, 33.781250f, 34.781250f, 34.000000f, 35.000000f, 36.000000f, + 34.281250f, 35.281250f, 36.281250f, 32.125000f, 33.125000f, 34.125000f, 33.343750f, 34.343750f, 35.343750f, + 35.125000f, 36.125000f, 37.125000f, 36.625000f, 37.625000f, 38.625000f, 38.125000f, 39.125000f, 40.125000f, + 39.906250f, 40.906250f, 41.906250f, 41.125000f, 42.125000f, 43.125000f, 41.406250f, 42.406250f, 43.406250f, + 37.000000f, 38.000000f, 39.000000f, 38.218750f, 39.218750f, 40.218750f, 40.000000f, 41.000000f, 42.000000f, + 41.500000f, 42.500000f, 43.500000f, 43.000000f, 44.000000f, 45.000000f, 44.781250f, 45.781250f, 46.781250f, + 46.000000f, 47.000000f, 48.000000f, 46.281250f, 47.281250f, 48.281250f, 38.125000f, 39.125000f, 
40.125000f, + 39.343750f, 40.343750f, 41.343750f, 41.125000f, 42.125000f, 43.125000f, 42.625000f, 43.625000f, 44.625000f, + 44.125000f, 45.125000f, 46.125000f, 45.906250f, 46.906250f, 47.906250f, 47.125000f, 48.125000f, 49.125000f, + 47.406250f, 48.406250f, 49.406250f, + }); + input.linspace(1); + auto size = NDArrayFactory::create({8, 8}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Resized to 8x8"); +// expected.printBuffer("Expect for 8x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test6) { + + NDArray input = NDArrayFactory::create('c', {7, 7, 1}, { + 1.f, 2.1f, 3.15f, 4.2f, 5.15f, 6.1f, 7.f, + 8.f, 9.1f, 10.f, 11.f, 12.9f, 13.1f, 14.f, + 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, + 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, + 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, + 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f + }); + + NDArray expected = NDArrayFactory::create('c', {30, 30, 1}, { + 1.000000f, 1.197616f, 1.417436f, 1.677577f, 1.996158f, 2.328327f, 2.550918f, 2.736061f, 2.965541f, + 3.292965f, 3.544151f, 3.738035f, 3.948995f, 4.248106f, 4.507379f, 4.684374f, 4.857284f, 5.104302f, + 5.386991f, 5.581401f, 5.753962f, 5.974285f, 6.272836f, 6.520426f, 6.718899f, 6.887104f, 7.039068f, + 7.099216f, 7.078424f, 7.028189f, 2.247592f, 2.446947f, 2.669489f, 2.931238f, 3.248216f, 3.574534f, + 3.789310f, 3.965697f, 4.186417f, 4.504653f, 4.740569f, 4.921706f, 5.133866f, 5.459533f, 5.774461f, + 6.019787f, 6.254011f, 6.535633f, 6.809730f, 6.960779f, 7.074942f, 7.241601f, 7.509489f, 7.749949f, + 7.954571f, 8.131972f, 8.286526f, 8.346463f, 8.325745f, 8.275683f, 3.628684f, 3.830573f, 4.056959f, + 4.321157f, 4.636486f, 4.955650f, 5.160583f, 5.325847f, 5.535462f, 5.842160f, 6.058749f, 
6.223753f, + 6.437597f, 6.797369f, 7.183604f, 7.516402f, 7.829034f, 8.154773f, 8.417635f, 8.512958f, 8.552100f, + 8.649708f, 8.877880f, 9.108794f, 9.320926f, 9.509781f, 9.667375f, 9.726940f, 9.706349f, 9.656599f, + 5.276778f, 5.480438f, 5.709702f, 5.975448f, 6.288551f, 6.600570f, 6.796207f, 6.951142f, 7.150400f, + 7.446143f, 7.644651f, 7.794562f, 8.009684f, 8.400473f, 8.851847f, 9.264690f, 9.649218f, 10.015648f, + 10.268647f, 10.313368f, 10.284327f, 10.319379f, 10.512033f, 10.734956f, 10.954604f, 11.154507f, 11.315369f, + 11.374779f, 11.354242f, 11.304622f, 7.325373f, 7.528484f, 7.757575f, 8.022221f, 8.331997f, 8.638187f, + 8.827649f, 8.976217f, 9.168955f, 9.457260f, 9.644237f, 9.784517f, 9.999621f, 10.407702f, 10.896234f, + 11.355122f, 11.781423f, 12.172186f, 12.420712f, 12.437449f, 12.370511f, 12.371386f, 12.545973f, 12.766424f, + 12.992249f, 13.200120f, 13.364252f, 13.424109f, 13.403420f, 13.353425f, 9.493208f, 9.692467f, 9.916944f, + 10.176801f, 10.482199f, 10.785470f, 10.974367f, 11.123442f, 11.316370f, 11.603645f, 11.790616f, 11.930889f, + 12.144082f, 12.546447f, 13.024898f, 13.472300f, 13.889232f, 14.276275f, 14.528972f, 14.555555f, 14.501450f, + 14.515459f, 14.700572f, 14.927055f, 15.156046f, 15.366046f, 15.532901f, 15.594008f, 15.572885f, 15.521847f, + 10.970133f, 11.163599f, 11.380694f, 11.633735f, 11.935032f, 12.238887f, 12.432540f, 12.588294f, 12.787534f, + 13.079956f, 13.277520f, 13.426631f, 13.636713f, 14.013844f, 14.441672f, 14.827978f, 15.191209f, 15.549808f, + 15.813430f, 15.881828f, 15.883522f, 15.950411f, 16.169330f, 16.407940f, 16.636436f, 16.842583f, 17.010887f, + 17.073630f, 17.051940f, 16.999537f, 12.219155f, 12.406129f, 12.614796f, 12.860335f, 13.157928f, 13.464224f, + 13.665207f, 13.830567f, 14.039036f, 14.339629f, 14.552863f, 14.715049f, 14.921564f, 15.264454f, 15.622843f, + 15.924977f, 16.213829f, 16.532364f, 16.809900f, 16.934835f, 17.012146f, 17.150164f, 17.413412f, 17.666712f, + 17.892765f, 18.092070f, 18.261044f, 18.325531f, 
18.303238f, 18.249378f, 13.766397f, 13.947391f, 14.148263f, + 14.386917f, 14.681246f, 14.990087f, 15.198166f, 15.372728f, 15.590062f, 15.898583f, 16.126892f, 16.301655f, + 16.504870f, 16.815214f, 17.107498f, 17.329458f, 17.547403f, 17.827654f, 18.118288f, 18.296928f, 18.446100f, + 18.651634f, 18.956806f, 19.223820f, 19.447308f, 19.639887f, 19.809319f, 19.875397f, 19.852556f, 19.797365f, + 15.941937f, 16.118704f, 16.314133f, 16.547867f, 16.839561f, 17.149540f, 17.361883f, 17.542162f, 17.764957f, + 18.078188f, 18.315733f, 18.498205f, 18.699116f, 18.988684f, 19.238989f, 19.410137f, 19.583265f, 19.839512f, + 20.138780f, 20.351770f, 20.546844f, 20.795671f, 21.128067f, 21.404358f, 21.626736f, 21.815500f, 21.985610f, + 22.052843f, 22.029604f, 21.973448f, 17.535220f, 17.710770f, 17.904636f, 18.136950f, 18.427840f, 18.738056f, + 18.951529f, 19.133352f, 19.357613f, 19.672083f, 19.912102f, 20.096638f, 20.296894f, 20.580765f, 20.819603f, + 20.976887f, 21.137802f, 21.387535f, 21.689209f, 21.911621f, 22.119276f, 22.379990f, 22.719910f, 22.998823f, + 23.220970f, 23.408760f, 23.579110f, 23.646685f, 23.623325f, 23.566887f, 18.746353f, 18.922657f, 19.117487f, + 19.350685f, 19.642070f, 19.952137f, 20.164913f, 20.345781f, 20.569134f, 20.882840f, 21.121330f, 21.304590f, + 21.505253f, 21.792645f, 22.038572f, 22.204426f, 22.372890f, 22.626648f, 22.926834f, 23.143423f, 23.343302f, + 23.596668f, 23.931936f, 24.209232f, 24.431519f, 24.619913f, 24.790110f, 24.857473f, 24.834190f, 24.777927f, + 20.166560f, 20.344206f, 20.540766f, 20.775532f, 21.067804f, 21.377607f, 21.589132f, 21.768297f, 21.990030f, + 22.302366f, 22.538124f, 22.719105f, 22.920494f, 23.214176f, 23.472767f, 23.653934f, 23.835890f, 24.096842f, + 24.394371f, 24.600555f, 24.786541f, 25.026773f, 25.353731f, 25.628130f, 25.850672f, 26.040140f, 26.210072f, + 26.277063f, 26.253906f, 26.197956f, 22.363024f, 22.541250f, 22.738552f, 22.973991f, 23.266647f, 23.576340f, + 23.787327f, 23.965760f, 24.186796f, 24.498543f, 24.733124f, 
24.913122f, 25.114826f, 25.411213f, 25.675262f, + 25.863028f, 26.050789f, 26.314838f, 26.611223f, 26.812925f, 26.992926f, 27.227505f, 27.550882f, 27.824034f, + 28.046684f, 28.236614f, 28.406433f, 28.473265f, 28.450163f, 28.394344f, 24.429443f, 24.607670f, 24.804970f, + 25.040410f, 25.333065f, 25.642756f, 25.853743f, 26.032173f, 26.253210f, 26.564959f, 26.799540f, 26.979540f, + 27.181242f, 27.477630f, 27.741680f, 27.929441f, 28.117207f, 28.381254f, 28.677637f, 28.879343f, 29.059345f, + 29.293922f, 29.617298f, 29.890451f, 30.113104f, 30.303034f, 30.472853f, 30.539684f, 30.516582f, 30.460762f, + 26.000000f, 26.178228f, 26.375526f, 26.610970f, 26.903624f, 27.213314f, 27.424305f, 27.602734f, 27.823772f, + 28.135519f, 28.370100f, 28.550098f, 28.751800f, 29.048190f, 29.312237f, 29.500000f, 29.687763f, 29.951813f, + 30.248200f, 30.449903f, 30.629902f, 30.864483f, 31.187859f, 31.461012f, 31.683659f, 31.873592f, 32.043407f, + 32.110240f, 32.087135f, 32.031320f, 27.570559f, 27.748787f, 27.946087f, 28.181528f, 28.474184f, 28.783876f, + 28.994865f, 29.173294f, 29.394330f, 29.706080f, 29.940659f, 30.120655f, 30.322360f, 30.618746f, 30.882797f, + 31.070557f, 31.258320f, 31.522371f, 31.818754f, 32.020460f, 32.200460f, 32.435040f, 32.758415f, 33.031567f, + 33.254220f, 33.444150f, 33.613964f, 33.680794f, 33.657696f, 33.601880f, 29.636976f, 29.815207f, 30.012500f, + 30.247944f, 30.540600f, 30.850290f, 31.061283f, 31.239712f, 31.460750f, 31.772500f, 32.007080f, 32.187077f, + 32.388780f, 32.685165f, 32.949215f, 33.136980f, 33.324740f, 33.588790f, 33.885178f, 34.086884f, 34.266880f, + 34.501457f, 34.824837f, 35.097990f, 35.320637f, 35.510574f, 35.680390f, 35.747215f, 35.724117f, 35.668300f, + 31.833440f, 32.011665f, 32.208970f, 32.444412f, 32.737070f, 33.046757f, 33.257744f, 33.436176f, 33.657207f, + 33.968960f, 34.203537f, 34.383537f, 34.585240f, 34.881630f, 35.145676f, 35.333440f, 35.521206f, 35.785255f, + 36.081642f, 36.283340f, 36.463340f, 36.697920f, 37.021297f, 37.294453f, 
37.517097f, 37.707027f, 37.876846f, + 37.943680f, 37.920578f, 37.864758f, 33.253647f, 33.431873f, 33.629170f, 33.864613f, 34.157270f, 34.466957f, + 34.677948f, 34.856377f, 35.077415f, 35.389160f, 35.623745f, 35.803745f, 36.005447f, 36.301834f, 36.565884f, + 36.753647f, 36.941406f, 37.205456f, 37.501840f, 37.703545f, 37.883545f, 38.118122f, 38.441500f, 38.714653f, + 38.937300f, 39.127235f, 39.297054f, 39.363884f, 39.340782f, 39.284960f, 34.464783f, 34.643010f, 34.840305f, + 35.075752f, 35.368404f, 35.678100f, 35.889088f, 36.067516f, 36.288550f, 36.600300f, 36.834885f, 37.014877f, + 37.216583f, 37.512970f, 37.777020f, 37.964783f, 38.152546f, 38.416595f, 38.712980f, 38.914684f, 39.094685f, + 39.329260f, 39.652645f, 39.925793f, 40.148440f, 40.338375f, 40.508194f, 40.575024f, 40.551920f, 40.496105f, + 36.058067f, 36.236290f, 36.433590f, 36.669033f, 36.961685f, 37.271378f, 37.482370f, 37.660800f, 37.881836f, + 38.193590f, 38.428170f, 38.608162f, 38.809868f, 39.106250f, 39.370300f, 39.558064f, 39.745830f, 40.009880f, + 40.306267f, 40.507970f, 40.687970f, 40.922550f, 41.245926f, 41.519077f, 41.741722f, 41.931652f, 42.101475f, + 42.168304f, 42.145203f, 42.089386f, 38.315002f, 38.493233f, 38.690533f, 38.925976f, 39.218628f, 39.528320f, + 39.739307f, 39.917736f, 40.138775f, 40.450520f, 40.685104f, 40.865097f, 41.066803f, 41.363190f, 41.627243f, + 41.815002f, 42.002766f, 42.266820f, 42.563200f, 42.764908f, 42.944904f, 43.179485f, 43.502860f, 43.776016f, + 43.998665f, 44.188595f, 44.358418f, 44.425247f, 44.402145f, 44.346330f, 40.227080f, 40.405310f, 40.602608f, + 40.838050f, 41.130707f, 41.440395f, 41.651382f, 41.829820f, 42.050854f, 42.362600f, 42.597183f, 42.777180f, + 42.978880f, 43.275270f, 43.539320f, 43.727080f, 43.914845f, 44.178894f, 44.475280f, 44.676983f, 44.856983f, + 45.091560f, 45.414940f, 45.688090f, 45.910740f, 46.100674f, 46.270493f, 46.337322f, 46.314220f, 46.258400f, + 41.785618f, 41.963844f, 42.161144f, 42.396584f, 42.689240f, 42.998936f, 43.209923f, 
43.388355f, 43.609394f, + 43.921143f, 44.155720f, 44.335716f, 44.537420f, 44.833805f, 45.097860f, 45.285614f, 45.473377f, 45.737427f, + 46.033817f, 46.235523f, 46.415524f, 46.650105f, 46.973476f, 47.246630f, 47.469276f, 47.659210f, 47.829030f, + 47.895855f, 47.872753f, 47.816940f, 43.115140f, 43.293365f, 43.490665f, 43.726105f, 44.018764f, 44.328457f, + 44.539444f, 44.717873f, 44.938910f, 45.250660f, 45.485240f, 45.665237f, 45.866940f, 46.163326f, 46.427376f, + 46.615143f, 46.802902f, 47.066956f, 47.363342f, 47.565050f, 47.745050f, 47.979626f, 48.302998f, 48.576153f, + 48.798798f, 48.988730f, 49.158546f, 49.225376f, 49.202282f, 49.146458f, 44.303867f, 44.482094f, 44.679394f, + 44.914833f, 45.207493f, 45.517180f, 45.728170f, 45.906600f, 46.127640f, 46.439384f, 46.673965f, 46.853966f, + 47.055668f, 47.352055f, 47.616100f, 47.803867f, 47.991630f, 48.255680f, 48.552063f, 48.753770f, 48.933773f, + 49.168350f, 49.491726f, 49.764877f, 49.987526f, 50.177460f, 50.347275f, 50.414100f, 50.391006f, 50.335186f, + 44.771675f, 44.949905f, 45.147200f, 45.382645f, 45.675300f, 45.984990f, 46.195976f, 46.374413f, 46.595448f, + 46.907196f, 47.141773f, 47.321774f, 47.523476f, 47.819862f, 48.083910f, 48.271680f, 48.459446f, 48.723490f, + 49.019882f, 49.221580f, 49.401585f, 49.636160f, 49.959538f, 50.232693f, 50.455338f, 50.645270f, 50.815090f, + 50.881920f, 50.858818f, 50.803000f, 44.609966f, 44.788193f, 44.985493f, 45.220936f, 45.513590f, 45.823280f, + 46.034270f, 46.212700f, 46.433743f, 46.745490f, 46.980070f, 47.160065f, 47.361770f, 47.658157f, 47.922207f, + 48.109970f, 48.297733f, 48.561783f, 48.858166f, 49.059875f, 49.239872f, 49.474450f, 49.797830f, 50.070980f, + 50.293625f, 50.483560f, 50.653378f, 50.720203f, 50.697100f, 50.641280f, 44.219246f, 44.397472f, 44.594772f, + 44.830210f, 45.122868f, 45.432560f, 45.643543f, 45.821980f, 46.043020f, 46.354763f, 46.589344f, 46.769340f, + 46.971046f, 47.267433f, 47.531483f, 47.719242f, 47.907005f, 48.171050f, 48.467438f, 48.669140f, 
48.849144f, + 49.083720f, 49.407100f, 49.680256f, 49.902905f, 50.092834f, 50.262653f, 50.329483f, 50.306380f, 50.250570f + }); + + auto size = NDArrayFactory::create({30, 30}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + NDArray* result = results.at(0); + +// result.printBuffer("Resized to 30x30"); +// expected.printBuffer("Expect for 30x30"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test7) { + + NDArray input = NDArrayFactory::create('c', {2, 5, 5, 1}, { + 0.2303, 0.7950, 0.8171, 0.0451, 0.3690, 0.6846, 0.2727, 0.2770, 0.2381, 0.9511, + 0.4116, 0.3997, 0.4075, 0.6275, 0.8018, 0.0678, 0.6221, 0.2982, 0.1524, 0.2613, + 0.7425, 0.6036, 0.7926, 0.5838, 0.1361, 0.4154, 0.3634, 0.3741, 0.2088, 0.2989, + 0.3982, 0.5618, 0.7266, 0.1089, 0.2922, 0.3306, 0.2869, 0.6638, 0.3091, 0.9312, + 0.0240, 0.2893, 0.5632, 0.9625, 0.4189, 0.3854, 0.2743, 0.6754, 0.8820, 0.8699}); + + NDArray expected = NDArrayFactory::create('c', {2, 9, 9, 1}, { + 0.2303f, 0.54569f, 0.840649f, 0.92725444f, 0.65660673f, + 0.16641647f, 0.06117659f, 0.33279106f, 0.4023279f, 0.5139505f, + 0.49821317f, 0.4906872f, 0.537642f, 0.4070102f, 0.13030615f, + 0.258801f, 0.65352744f, 0.773368f, 0.69225276f, 0.44177493f, + 0.21910316f, 0.22368976f, 0.24221404f, 0.21399781f, 0.5114972f, + 0.9169859f, 1.0511527f, 0.5608501f, 0.41315168f, 0.2913824f, + 0.2966933f, 0.38585684f, 0.48849702f, 0.71013063f, 0.9086001f, + 0.9794303f, 0.29625386f, 0.39427578f, 0.45971435f, 0.39693952f, + 0.40860707f, 0.51061106f, 0.6181093f, 0.67309624f, 0.69564015f, + 0.06012487f, 0.3863805f, 0.58993465f, 0.40679216f, 0.22607432f, + 0.20093678f, 0.25901243f, 0.3615362f, 0.39371052f, 0.24176767f, + 0.4868709f, 0.650651f, 0.5493148f, 0.3825456f, 0.27788478f, + 0.18927254f, 0.16692996f, 0.15432167f, 0.677519f, 0.6236242f, + 0.61700624f, 0.7214321f, 
0.7307374f, 0.6251454f, 0.3924176f, + 0.17802659f, 0.10231908f, 0.81192374f, 0.66878575f, 0.6118803f, + 0.7797006f, 0.8396968f, 0.72889954f, 0.44547448f, 0.16794783f, + 0.07125802f, 0.4154f, 0.38504714f, 0.3623221f, 0.3862173f, + 0.3397379f, 0.23285517f, 0.21876639f, 0.2892362f, 0.30817088f, + 0.41268015f, 0.45587808f, 0.51991886f, 0.60977113f, 0.49489656f, + 0.21313031f, 0.11297428f, 0.2167207f, 0.23940037f, 0.39337245f, + 0.46112412f, 0.583034f, 0.76207364f, 0.6326203f, 0.22189438f, + 0.12071565f, 0.3275853f, 0.3794855f, 0.38497013f, 0.35049653f, + 0.41895086f, 0.671095f, 0.62119365f, 0.22362521f, 0.30189657f, + 0.72530353f, 0.85048175f, 0.2524255f, 0.2182264f, 0.2964637f, + 0.5361996f, 0.6255393f, 0.46424767f, 0.5741281f, 0.8408146f, + 0.92403257f, 0.04648584f, 0.14959256f, 0.32215607f, 0.46194845f, + 0.6642166f, 0.83560026f, 0.7663391f, 0.5284251f, 0.4573109f, + 0.10357999f, 0.17442937f, 0.32116935f, 0.45530772f, 0.7163773f, + 0.9856574f, 0.8976148f, 0.5538923f, 0.45173654f, 0.34958175f, + 0.2680429f, 0.30470955f, 0.51233786f, 0.75128907f, 0.86736864f, + 0.8982046f, 0.83254474f, 0.8168574f, 0.4225865f, 0.2956836f, + 0.29948136f, 0.5276342f, 0.76461166f, 0.8442875f, 0.907862f, + 0.9139262f, 0.92068815f + }); + auto size = NDArrayFactory::create({9, 9}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Resized to 9x9"); +// expected.printBuffer("Expect for 9x9"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeBicubic_Test8) { + + NDArray input = NDArrayFactory::create('c', {2, 5, 5, 1}, { + 0.23028551377579154, 0.7949972231516509, 0.8171307820461517, 0.04507309923418412, 0.3689673597428338, + 0.6845757584903018, 0.27268547668219667, 0.2770196372806053, 0.2381478370531429, 0.9511201914609859, + 0.41160882670429033, 0.3997152563642703, 
0.4074505147711718, 0.6274595060113246, 0.8017922711300232, + 0.06782045852179475, 0.6220772280691722, 0.2982335327629251, 0.1523603480424196, 0.2612986044295986, + 0.7424762244324299, 0.6036156464824591, 0.7926371071102005, 0.5838270656432538, 0.13607200219168547, + 0.4154002170215956, 0.36340617544852116, 0.37405031188276827, 0.20880251686544882, 0.298919946410666, + 0.39820758164277126, 0.5617728968896589, 0.72660225993937, 0.10888245916813699, 0.29215797784445496, + 0.3305531351746034, 0.28693451964931715, 0.6637635348315494, 0.30913418229827583, 0.9312186188801752, + 0.0239594182399363, 0.2892942758780874, 0.5631691110629038, 0.9625499752246309, 0.4189439089689968, + 0.3854304088214935, 0.27426304203925045, 0.6754051704648238, 0.8820362490795286, 0.8699337744328859}); + + + auto testData = NDArrayFactory::create('c', {2,9,9,1}, { + 0.230286514f, 0.510566354f, 0.794997215f, 0.931386113f, 0.817130804f, 0.402811885f, 0.045073099f, 0.134639814f, 0.368967354f, + 0.483021289f, 0.501266003f, 0.521932304f, 0.572325349f, 0.534847379f, 0.267853439f, 0.105112493f, 0.349290252f, 0.674043298f, + 0.684575737f, 0.478224277f, 0.272685468f, 0.239882097f, 0.27701965f, 0.191148892f, 0.23814784f, 0.590989769f, 0.951120198f, + 0.622912169f, 0.441326082f, 0.266387194f, 0.232538164f, 0.301838756f, 0.356378645f, 0.495445013f, 0.756725252f, 0.981704295f, + 0.411608815f, 0.40493685f, 0.399715245f, 0.381842017f, 0.407450527f, 0.501836538f, 0.627459526f, 0.735251725f, 0.801792264f, + 0.150875032f, 0.357000858f, 0.524536073f, 0.450354964f, 0.318719596f, 0.319606483f, 0.385957927f, 0.46392554f, 0.529285908f, + 0.06782046f, 0.375309169f, 0.622077227f, 0.525792599f, 0.298233539f, 0.184723631f, 0.15236035f, 0.193153858f, 0.261298597f, + + 0.372918189f, 0.512539625f, 0.63369292f, 0.628733814f, 0.535196245f, 0.436597466f, 0.323553175f, 0.215942055f, 0.148014024f, + 0.742476225f, 0.655325174f, 0.603615642f, 0.704684138f, 0.79263711f, 0.747929871f, 0.583827078f, 0.340373576f, 0.136071995f, + 
0.415400207f, 0.388405323f, 0.363406181f, 0.379345775f, 0.374050319f, 0.28397581f, 0.208802521f, 0.238369256f, 0.298919946f, + 0.413146496f, 0.444389015f, 0.488355637f, 0.568351328f, 0.556217432f, 0.345546633f, 0.140068889f, 0.148834035f, 0.23562704f, + 0.398207575f, 0.464537472f, 0.561772883f, 0.717433035f, 0.726602256f, 0.416013002f, 0.108882457f, 0.142608985f, 0.292157978f, + 0.391511708f, 0.389470309f, 0.442729384f, 0.651181757f, 0.737665415f, 0.41685915f, 0.138383076f, 0.342548877f, 0.659080088f, + + 0.330553144f, 0.273416102f, 0.286934525f, 0.50450629f, 0.663763523f, 0.463456154f, 0.309134185f, 0.586929917f, 0.931218624f, + 0.137025774f, 0.169145152f, 0.263757467f, 0.436182201f, 0.597053051f, 0.657990932f, 0.662163854f, 0.68354249f, 0.692712903f, + 0.023959421f, 0.130951077f, 0.289294273f, 0.413664877f, 0.563169122f, 0.839498401f, 0.962549984f, 0.728188932f, 0.418943912f, + 0.175951749f, 0.198239252f, 0.281999886f, 0.420836329f, 0.609856486f, 0.863734365f, 0.983550847f, 0.825015843f, 0.596413136f, + 0.385430396f, 0.292239636f, 0.274263054f, 0.445040524f, 0.675405145f, 0.817462444f, 0.882036269f, 0.895356655f, 0.869933784f + }); + + auto size = NDArrayFactory::create({9, 9}); + sd::ops::resize_bicubic op; + auto results = op.evaluate({&input, &size}, {}, {}, {true, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Resized to 9x9"); +// testData.printBuffer("Expect for 9x9"); + ASSERT_TRUE(testData.isSameShape(result)); + ASSERT_TRUE(testData.equalsTo(result)); +} + + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test1) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 4}); + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 4}, { + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 9.f, 10.f, 11.f, 12.f, + + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 
12.f, + 9.f, 10.f, 11.f, 12.f, + + 13.f, 14.f, 15.f, 16.f, + 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, + 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, + 21.f, 22.f, 23.f, 24.f, + + 13.f, 14.f, 15.f, 16.f, + 13.f, 14.f, 15.f, 16.f, + 17.f, 18.f, 19.f, 20.f, + 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, + 21.f, 22.f, 23.f, 24.f, + + 25.f, 26.f, 27.f, 28.f, + 25.f, 26.f, 27.f, 28.f, + 29.f, 30.f, 31.f, 32.f, + 29.f, 30.f, 31.f, 32.f, + 33.f, 34.f, 35.f, 36.f, + 33.f, 34.f, 35.f, 36.f, + + 25.f, 26.f, 27.f, 28.f, + 25.f, 26.f, 27.f, 28.f, + 29.f, 30.f, 31.f, 32.f, + 29.f, 30.f, 31.f, 32.f, + 33.f, 34.f, 35.f, 36.f, + 33.f, 34.f, 35.f, 36.f }); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test2) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}); + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { + 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, + 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, + 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, + 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, + 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, + 7.f, 7.f, 8.f, 8.f, 9.f, 9.f + }); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test3) { + + NDArray input = 
NDArrayFactory::create('c', {1, 3, 3, 3}); + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 3}, { + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f + }); + input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test4) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 3, 3}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27 + }); + + NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 3}, { + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 
24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f + }); + //input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test5) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 3, 3}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27 + }); + + NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 3}, { + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 
19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, + 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f + }); + //input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test6) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 3, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, + 1, 2, 3, 4, 5, 6, 7, 8, 9 + }); + + NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 1}, { + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, + + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f + }); + //input.linspace(1); + auto size = NDArrayFactory::create({6, 6}); + 
sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test7) { + + NDArray input = NDArrayFactory::create('c', {2, 3, 3, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, + 1, 2, 3, 4, 5, 6, 7, 8, 9 + }); + + NDArray expected = NDArrayFactory::create('c', {2, 6, 6, 1}, { + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f, + + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f + }); + //input.linspace(1); +// auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {6, 6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test8) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9 + }); + + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 1.f, 1.f, 1.5f, 2.f, 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 4.f, 4.f, 4.5f, 5.f, 5.f, 6.f, + 7.f, 7.f, 7.5f, 8.f, 8.f, 9.f + }); + //input.linspace(1); +// auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = 
op.evaluate({&input}, {}, {6, 6}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ResizeImages_Test8) { + + NDArray input = NDArrayFactory::create('c', {1, 3, 3, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9 + }); + + NDArray expected = NDArrayFactory::create('c', {1, 6, 6, 1}, { +// 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 4.f, 4.f, 5.f, 5.f, +// 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f + 1.f , 1.f , 1.5f, 2.f , 2.f, 3.f, 1.f , 1.f , 1.5f, 2.f , 2.f, 3.f, + 2.5f, 2.5f, 3.f, 3.5f, 3.5f, 4.5f, 4.f , 4.f , 4.5f , 5.f, 5.f, 6.f , + 4.f, 4.f, 4.5f , 5.f, 5.f, 6.f, 7.f , 7.f , 7.5f , 8.f , 8.f , 9.f + }); + //input.linspace(1); +// auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_images op; + auto results = op.evaluate({&input}, {}, {6, 8, ops::helpers::kResizeArea}, {true, true}); // resize_area to 6x8 with align corners and preserve aspect ratio of input image + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result->printBuffer("Area Resized to 6x6"); +// expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test9) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 + }); + + NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, { + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, + 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 
6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, + 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, + 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 
9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 
23.999996f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999989f, 21.999989f, 22.999987f, 23.999987f + + }); + //input.linspace(1); + auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 10x10"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test10) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 + }); + + NDArray expected = NDArrayFactory::create('c', {1, 10, 10, 4}, { + 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 
7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333337f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 9.000000f, 10.000000f, 11.000000f, 12.000000f, 8.999998f, 9.999998f, 10.999998f, 11.999998f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 1.000000f, 2.000000f, 3.000000f, 4.000000f, 3.666667f, 4.666667f, 5.666667f, 6.666667f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 5.000000f, 6.000000f, 7.000000f, 8.000000f, 6.333336f, 7.333336f, 8.333336f, 9.333336f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999999f, 9.999999f, 11.000000f, 11.999999f, 8.999998f, 9.999997f, 10.999997f, 11.999997f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 13.000003f, 14.000004f, 15.000003f, 16.000004f, 15.666671f, 16.666672f, 17.666672f, 18.666672f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 17.000006f, 18.000004f, 19.000006f, 20.000004f, 18.333344f, 19.333344f, 20.333345f, 21.333344f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000006f, 22.000006f, 23.000006f, 24.000006f, 21.000002f, 22.000000f, 23.000002f, 24.000000f, 
13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 13.000000f, 14.000001f, 15.000000f, 16.000000f, 15.666667f, 16.666668f, 17.666668f, 18.666668f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 17.000002f, 18.000000f, 19.000002f, 20.000000f, 18.333340f, 19.333340f, 20.333342f, 21.333340f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 21.000002f, 22.000000f, 22.999998f, 24.000000f, 20.999996f, 21.999996f, 22.999994f, 23.999996f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 12.999995f, 13.999995f, 14.999994f, 15.999994f, 15.666661f, 16.666662f, 17.666660f, 18.666660f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 16.999994f, 17.999994f, 18.999992f, 19.999992f, 18.333334f, 19.333332f, 20.333334f, 21.333332f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999992f, 21.999992f, 22.999990f, 23.999992f, 20.999989f, 21.999989f, 22.999987f, 23.999987f + + }); + //input.linspace(1); + //auto size = NDArrayFactory::create({10, 10}); 
+ sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {10, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 10x10"); + // expected.printBuffer("Area Expect for 6x6"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test11) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 + }); + +// NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { +// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 
4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 
22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 +// +// }); + //input.linspace(1); + //auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {6, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 6x9"); + // expected.printBuffer("Area Expect for 6x6"); +// ASSERT_TRUE(expected.isSameShape(result)); +// ASSERT_TRUE(expected.equalsTo(result)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test12) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 + }); + +// NDArray expected = NDArrayFactory::create('c', {1, 6, 9, 4}, { +// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 
8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 
17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 +// +// }); + 
//input.linspace(1); + //auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {10, 15}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 6x9"); + // expected.printBuffer("Area Expect for 6x6"); +// ASSERT_TRUE(expected.isSameShape(result)); +// ASSERT_TRUE(expected.equalsTo(result)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test13) { + + NDArray input = NDArrayFactory::create('c', {1, 2, 3, 4}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,12,13,14,15,16,17,18,19,20,21,22,23,24 + }); + +// NDArray expected = NDArrayFactory::create('c', {1, 8, 8, 4}, { +// 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 
1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333337, 9.000000, 10.000000, 11.000000, 12.000000, 9.000000, 10.000000, 11.000000, 12.000000, 8.999998, 9.999998, 10.999998, 11.999998, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 1.000000, 2.000000, 3.000000, 4.000000, 3.666667, 4.666667, 5.666667, 6.666667, 5.000000, 6.000000, 7.000000, 8.000000, 5.000000, 6.000000, 7.000000, 8.000000, 6.333336, 7.333336, 8.333336, 9.333336, 8.999999, 9.999999, 11.000000, 11.999999, 8.999999, 9.999999, 11.000000, 11.999999, 8.999998, 9.999997, 10.999997, 11.999997, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 13.000003, 14.000004, 15.000003, 16.000004, 15.666671, 16.666672, 17.666672, 18.666672, 17.000006, 18.000004, 19.000006, 20.000004, 17.000006, 18.000004, 19.000006, 20.000004, 18.333344, 19.333344, 20.333345, 21.333344, 21.000006, 22.000006, 23.000006, 24.000006, 21.000006, 22.000006, 23.000006, 24.000006, 21.000002, 22.000000, 23.000002, 24.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 
21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 13.000000, 14.000001, 15.000000, 16.000000, 15.666667, 16.666668, 17.666668, 18.666668, 17.000002, 18.000000, 19.000002, 20.000000, 17.000002, 18.000000, 19.000002, 20.000000, 18.333340, 19.333340, 20.333342, 21.333340, 21.000002, 22.000000, 22.999998, 24.000000, 21.000002, 22.000000, 22.999998, 24.000000, 20.999996, 21.999996, 22.999994, 23.999996, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 12.999995, 13.999995, 14.999994, 15.999994, 15.666661, 16.666662, 17.666660, 18.666660, 16.999994, 17.999994, 18.999992, 19.999992, 16.999994, 17.999994, 18.999992, 19.999992, 18.333334, 19.333332, 20.333334, 21.333332, 20.999992, 21.999992, 22.999990, 23.999992, 20.999992, 21.999992, 22.999990, 23.999992, 20.999989, 21.999989, 22.999987, 23.999987 +// +// }); + //input.linspace(1); + //auto size = NDArrayFactory::create({10, 10}); + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {9, 9}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); + +// result.printBuffer("Area Resized to 8x8"); + // expected.printBuffer("Area Expect for 6x6"); +// ASSERT_TRUE(expected.isSameShape(result)); +// ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test14) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({8, 7}); + NDArray expected = NDArrayFactory::create('c', {1, 8, 7, 1}, { + 1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f, + 4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f, + 9.399994f, 
10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f, + 13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f, + 18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f , + 20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , + 21.599998f , 22.199995f , 22.999998f , 23.800001f , 24.399984f , + 25.f + }); //input.linspace(1); +// auto size = NDArrayFactory::create({6, 6}); + sd::ops::resize_area op; + auto results = op.evaluate({&input, &size}, {}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("Area Resized to 8x7"); +// expected.printBuffer("Area Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests11, ImageResizeArea_Test15) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + //auto size = NDArrayFactory::create({8, 7}); + NDArray expected = NDArrayFactory::create('c', {1, 8, 7, 1}, { + 1.f, 1.6f , 2.1999993f, 2.9999995f , 3.8f , 4.399997f, 5.f , 2.9999995f , 3.5999997f , 4.199999f, + 4.9999995f, 5.8f , 6.3999963f , 7.f , 5.999999f , 6.6f, 7.1999984f , 7.9999995f , 8.8f, + 9.399994f, 10.f , 10.f, 10.6f , 11.199998f, 12.f, 12.8f, 13.399992f , 14.f, 12.f , 12.599999f, + 13.199998f , 13.999998f , 14.800002f , 15.399991f , 16.f , 15.999999f , 16.599998f , 17.199995f, + 18.f , 18.800003f , 19.399986f , 20.000002f , 19.f , 19.599998f , 20.199997f , + 20.999998f , 21.800003f , 22.399984f , 23.000002f , 20.999998f , 21.599998f , 22.199995f , + 22.999998f , 23.800001f , 24.399984f , 25.f + }); + + sd::ops::resize_area op; + auto results = op.evaluate({&input}, {}, {8, 7}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray* result = results.at(0); +// result.printBuffer("Area Resized to 
8x7"); +// expected.printBuffer("Area Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, summaryStatsData_test1) { + + functions::summarystats::SummaryStatsData var1; + functions::summarystats::SummaryStatsData var2; + var2.n = var2.mean = var2.M2 = var2.M3 = var2.M4 = var2.bias = 5; + + functions::summarystats::SummaryStatsData* arr = new functions::summarystats::SummaryStatsData[2]; + arr[0] = var1; + arr[1] = var2; + arr[0] = arr[1]; + + functions::summarystats::SummaryStatsData var3(var1); + + ASSERT_TRUE(arr[0].n == arr[0].mean && arr[0].M2 == arr[0].M3 && arr[0].n == 5); + ASSERT_TRUE(arr[1].n == arr[1].mean && arr[1].M2 == arr[1].M3 && arr[1].n == 5); + ASSERT_TRUE(var3.n == var3.mean && var3.M2 == var3.M3 && var3.n == 0); + + delete []arr; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_1) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 2.f, -1.f, -2.f, -4.f, 6.f, 3.f, -4.f, -2.f, 8.f + }); + + auto b = NDArrayFactory::create('c', {3, 1}, { + 2.f, 4.f, 3.f + }); + + auto exp = NDArrayFactory::create('c', {3, 1}, { + 7.625f, 3.25f, 5.f + }); + + sd::ops::solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("Solve of 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_2) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 1.f, 1.f, 1.f, 1.f, + 0.f, 1.f, 1.f, 0.f, + 0.f, 0.f, 2.f, 1.f, + 0.f, 0.f, 0.f, 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 2.f, 4.f, 2.f, 4.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f + }); + + sd::ops::solve 
op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("Solve 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_3) { + + auto a = NDArrayFactory::create('c', {2, 4, 4}, { + 1.f, 1.f, 1.f, 1.f, + 0.f, 1.f, 1.f, 0.f, + 0.f, 0.f, 2.f, 1.f, + 0.f, 0.f, 0.f, 3.f, + + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f + + }); + + auto b = NDArrayFactory::create('c', {2, 4, 1}, { + 2.f, 4.f, 2.f, 4.f, + 4.f, 2.f, 4.f, 2.f + }); + + auto exp = NDArrayFactory::create('c', {2, 4, 1}, { + -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f, + 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f + }); + + sd::ops::solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("Solve 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4) { + + auto a = NDArrayFactory::create('c', {2, 2, 2}, { + 0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f + }); + + auto b = NDArrayFactory::create('c', {2, 2, 2}, { + 0.7717f, 0.9281f, 0.9846f, 0.4838f, + 0.6433f, 0.6041f, 0.6501f, 0.7612f + }); + + auto exp = NDArrayFactory::create('c', {2, 2, 2}, { +// 1.524494767f, 0.432706356f,-0.518630624f, 0.737760842f, +// 0.819143713f, 0.720401764f, 0.264349997f, 0.444699198f + 1.5245394f, 0.4326952f, -0.51873577f, 0.7377896f, + 0.81915987f, 0.72049433f, 0.2643504f, 0.44472617f + }); + + sd::ops::solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printBuffer("4 Solve 4x4"); +// exp.printBuffer("4 Expec 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_1) { + + auto a = NDArrayFactory::create('c', {2, 2, 2}, { + 0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f + }); + + auto b = NDArrayFactory::create('c', {2, 2, 2}, { + 0.7717f, 0.9281f, 0.9846f, 0.4838f, 0.6433f, 0.6041f, 0.6501f, 0.7612f + }); + + auto exp = NDArrayFactory::create('c', {2, 2, 2}, { + 1.3357621f, 0.3399364f, -0.37077796f, 0.91573375f, + 0.4400987f, 0.2766527f, 0.6394467f, 0.79696566f + }); + + sd::ops::solve op; + + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printBuffer("4 Solve 4x4"); +// exp.printBuffer("4 Expec 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_2) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.99088347f, 1.1917052f, 1.2642528f, + 0.35071516f, 0.50630623f, 0.42935497f, + -0.30013534f, -0.53690606f, -0.47959247f + }); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {true, false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printBuffer("4_2 Triangular_Solve 3x3"); +// exp.printBuffer("4_2 Triangular_Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_3) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = 
NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.45400196f, 0.53174824f, 0.62064564f, + -0.79585856f, -0.82621557f, -0.87855506f, + 1.1904413f, 1.3938838f, 1.3926021f + }); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {true, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printBuffer("4_3 Triangular_Solve 3x3"); +// exp.printBuffer("4_3 Triangular_Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_4) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.8959121f, 1.6109066f, 1.7501404f, + 0.49000582f, 0.66842675f, 0.5577021f, + -0.4398522f, -1.1899745f, -1.1392052f + }); + + sd::ops::solve op; + + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printBuffer("4_4 Solve 3x3"); +// exp.printBuffer("4_4 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_5) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 1.5504692f, 1.8953944f, 2.2765768f, + 0.03399149f, 0.2883001f, 0.5377323f, + 
-0.8774802f, -1.2155888f, -1.8049058f + }); + + sd::ops::solve op; + + auto res = op.evaluate({&a, &b}, {true, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printBuffer("4_5 Solve 3x3"); +// exp.printBuffer("4_5 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_6) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.99088347f, 1.1917052f, 1.2642528f, + -0.426483f, -0.42840624f, -0.5622601f, + 0.01692283f, -0.04538865f, -0.09868701f + }); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {false, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printBuffer("4_6 Solve 3x3"); +// exp.printBuffer("4_6 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_4_7) { + + auto a = NDArrayFactory::create('c', {3, 3}, { +// 0.7788f, 0.2309f, 0.5056f, +// 0.8012f, 0.7271f, 0.8925f, +// 0.7244f, 0.1804f, 0.5461f + + 0.7788f, 0.2309f, 0.5056f, + 0.8012f, 0.7271f, 0.8925f, + 0.7244f, 0.1804f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 0.99088347f, 1.1917052f, 1.2642528f, + -0.426483f, -0.42840624f, -0.5622601f, + 0.01692283f, -0.04538865f, -0.09868701f + }); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {true, false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = 
res.at(0); + +// z->printBuffer("4_7 Solve 3x3"); +// exp.printBuffer("4_7 Expec 3x3"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Solve_Test_5) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 0.7788f, 0.8012f, 0.7244f, + 0.2309f, 0.7271f, 0.1804f, + 0.5056f, 0.8925f, 0.5461f + }); + + auto b = NDArrayFactory::create('c', {3, 3}, { + 0.7717f, 0.9281f, 0.9846f, + 0.4838f, 0.6433f, 0.6041f, + 0.6501f, 0.7612f, 0.7605f + }); + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 1.5504692f, 1.8953944f, 2.2765768f, + 0.03399149f, 0.2883001f, 0.5377323f, + -0.8774802f, -1.2155888f, -1.8049058f + }); + + sd::ops::solve op; + + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printBuffer("4 Solve 4x4"); +// exp.printBuffer("4 Expec 4x4"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, SolveLS_Test_1) { + + auto a = NDArrayFactory::create('c', {2,2, 2}, { + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f + }); + + auto b = NDArrayFactory::create('c', {2, 2, 1}, { + 3.f, 7.f, 11.f, 15.f + }); + + auto exp = NDArrayFactory::create('c', {2, 2, 1}, { + 0.8311695f, 1.0909086f, 0.9205573f, 1.0630057f + }); + + sd::ops::lstsq op; + + auto res = op.evaluate({&a, &b}, {0.5}, {}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("LS Solve 2x2"); +// exp.printIndexedBuffer("LS Expec 2x2"); + + ASSERT_TRUE(exp.equalsTo(z, 1.e-4)); +} + +TEST_F(DeclarableOpsTests11, SolveLS_Test_2) { + + auto a = NDArrayFactory::create('c', {2,2, 2}, { + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f + }); + + auto b = NDArrayFactory::create('c', {2, 2, 1}, { + 3.f, 7.f, 11.f, 15.f + }); + + auto exp = NDArrayFactory::create('c', {2, 2, 1}, { + 0.8311695f, 1.0909086f, 0.9205573f, 
1.0630057f + }); + + sd::ops::lstsq op; + + auto res = op.evaluate({&a, &b}, {0.5}, {}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("2LS Solve 2x2"); +// exp.printIndexedBuffer("2LS Expec 2x2"); + + ASSERT_TRUE(exp.equalsTo(z, 1.e-4)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2) { + + auto a = NDArrayFactory::create('c', {2,2, 2}, { + 10.f, 14.f, + 14.f, 20.f, + + 74.f, 86.f, + 86.f, 100.f + }); + + auto exp = NDArrayFactory::create('c', {2, 2, 2}, { + 3.1622777f, 0.f, 4.427189f, 0.6324552f, + 8.602325f, 0.f, 9.997296f, 0.23252854f + }); + + sd::ops::cholesky op; + + auto res = op.evaluate({&a}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + + z->printIndexedBuffer("L matrix is"); + exp.printIndexedBuffer("L expected is"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Cholesky_Test_2x2x2_2) { + + auto a = NDArrayFactory::create('c', {2,2, 2}, { + 10.5f, 14.f, + 14.f, 20.5f, + + 74.5f, 86.f, + 86.f, 100.5f + }); + + auto exp = NDArrayFactory::create('c', {2, 2, 2}, { + 3.2403703f, 0.f, 4.3204937f, 1.3540066f, + 8.631338f, 0.f, 9.963693f, 1.1067207f + }); + + sd::ops::cholesky op; + + auto res = op.evaluate({&a}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("L matrix is"); +// exp.printIndexedBuffer("L expected is"); + MmulHelper::matmul(z, z, &exp, false, true); + ASSERT_TRUE(exp.equalsTo(a)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test1) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, 
-2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52, + -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04}); + NDArray dLdwExp('c', {2,3,4}, {0.9216 , 3.6864 , 8.2944 , 14.7456 , 23.04 , 33.1776 , 45.1584 , 58.9824 , 74.6496 , 92.16 ,111.51361,132.7104 , + 155.75038,180.63359,207.35999,235.9296 ,266.34238,298.59842,332.6976 ,368.64001,406.4256 ,446.05444,487.5264 ,530.84161}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test2) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {2,1,4}, {98.61121,129.024 , 164.9664 , 206.4384 , 828.51837,925.28644,1027.58398,1135.41113}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test3) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray 
predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.96, -1.92, -2.88, -3.84, -4.8 , -5.76, -6.72, -7.68, -8.64, -9.6 ,-10.56,-11.52, + -12.48,-13.44,-14.4 ,-15.36,-16.32,-17.28,-18.24,-19.2 ,-20.16,-21.12,-22.08,-23.04}); + NDArray dLdwExp('c', {}, std::vector{4515.84}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test4) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {807.32153, 1426.63684, 2281.88159}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test5) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray 
dLdpExp('c', {2,3,4}, {-0.08,-0.16,-0.24,-0.32,-0.4 ,-0.48,-0.56,-0.64,-0.72,-0.8 ,-0.88,-0.96, + -1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92}); + NDArray dLdwExp('c', {2,3,4}, {-15.6032,-15.3728,-14.9888,-14.4512,-13.76 ,-12.9152,-11.9168,-10.7648, -9.4592, -8. , -6.3872, -4.6208, + -2.7008, -0.6272, 1.6 , 3.9808, 6.5152, 9.2032, 12.0448, 15.04 , 18.1888, 21.4912, 24.9472, 28.5568}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test6) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {-58.16319, -6.5536 , 64.71682}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test7) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray 
weights(sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {}, std::vector{0.}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test8) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0. ,0. ,0. ,0. ,-0.48 ,-0.576,-0.672,-0.768,-0.864,-0.96 ,-1.056,-1.152, + -1.248,-1.344,-1.44 ,-1.536,-1.632,-1.728,-1.824,-1.92 ,-2.016,-2.112,-2.208,-2.304}); + NDArray dLdwExp('c', {2,3,4}, {-22.3488 ,-22.07232,-21.61152,-20.9664 ,-20.13696,-19.1232 ,-17.92512,-16.54272,-14.976 ,-13.22496,-11.2896 , -9.16992, + -6.86592, -4.3776 , -1.70496, 1.152 , 4.19328, 7.41888, 10.8288 , 14.42304, 18.2016 , 22.16449, 26.31168, 30.6432 }); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, 
mean_sqerr_loss_grad_test9) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.04,-0.08,-0.12,-0.16,-0.2 ,-0.24,-0.28,-0.32,-0.36,-0.4 ,-0.44,-0.48, + -0.52,-0.56,-0.6 ,-0.64,-0.68,-0.72,-0.76,-0.8 ,-0.84,-0.88,-0.92,-0.96}); + NDArray dLdwExp('c', {2,3,4}, {0.0384, 0.1536, 0.3456, 0.6144, 0.96 , 1.3824, 1.8816, 2.4576, 3.1104, 3.84 , 4.6464, 5.5296, + 6.4896, 7.5264, 8.64 , 9.8304,11.0976,12.4416,13.8624,15.36 ,16.9344,18.5856,20.3136,22.1184}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test10) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,1}, std::vector{188.16}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test11) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {33.6384 ,59.4432 ,95.07841}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test12) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0.,0.,0.,0., -0.24 ,-0.288,-0.336,-0.384,-0.432,-0.48 ,-0.528,-0.576, + -0.624,-0.672,-0.72 ,-0.768,-0.816,-0.864,-0.912,-0.96 ,-1.008,-1.056,-1.104,-1.152}); + NDArray dLdwExp('c', {2,3,4}, {0.04608, 0.18432, 0.41472, 0.73728, 1.152 , 1.65888, 2.25792, 2.94912, 3.73248, 4.608 , 5.57568, 6.63552, + 7.78752, 9.03168,10.368 ,11.79648,13.31712,14.92992,16.63488,18.432 ,20.32128,22.30272,24.37632,26.54208}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + weights.r(3) = 0.; + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + 
ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, mean_sqerr_loss_grad_test13) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + -1.04,-1.12,-1.2 ,-1.28,-1.36,-1.44,-1.52,-1.6 ,-1.68,-1.76,-1.84,-1.92}); + NDArray dLdwExp('c', {2,3,1}, {2.304 , 13.3632 , 34.2528 , 64.97279,105.5232 ,155.90401}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + + sd::ops::mean_sqerr_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test1) { + auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); + auto exp = NDArrayFactory::create('c', {4}, {9, 1,1, 9}); + sd::ops::squaredsubtract op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test2) { + auto x = NDArrayFactory::create('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); + auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); + auto exp = NDArrayFactory::create('c', {2, 4}, {9, 1,1, 9, 9, 1, 1, 9}); + 
sd::ops::squaredsubtract op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + +} + +TEST_F(DeclarableOpsTests11, SquaredSubtractTest_Test3) { + auto x = NDArrayFactory::create('c', {2, 4}, {0, 1, 2, 3, 0, 1, 2, 3}); + auto y = NDArrayFactory::create('c',{4}, {3, 2, 1, 0}); + auto exp = NDArrayFactory::create('c', {2, 4}, {-6, -4, 6, 24, -30, -12, 14, 48}); + auto eps = NDArrayFactory::create('c', {2, 4}, {1,2,3,4,5,6,7,8}); + sd::ops::squaredsubtract_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test1) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5, + -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5}); + NDArray dLdwExp('c', {2,3,4}, {0.96, 1.92, 2.88, 3.84, 4.8 , 5.76, 6.72, 7.68, 8.64, 9.6 ,10.56,11.52, + 12.48,13.44,14.4 ,15.36,16.32,17.28,18.24,19.2 ,20.16,21.12,22.08,23.04}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdp = results.at(0); + auto dLdw = results.at(1); + auto dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test2) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {2,1,4}, {14.4 , 17.28, 20.16, 23.04, 48.96, 51.84, 54.72, 57.6}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test3) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5, + -0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5,-0.5}); + NDArray dLdwExp('c', {}, std::vector{288.}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, 
absolute_difference_loss_grad_test4) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {65.28, 96., 126.72001}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test5) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167, + -0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167,-0.04167}); + NDArray dLdwExp('c', {2,3,4}, {-0.92,-0.84,-0.76,-0.68,-0.6 ,-0.52,-0.44,-0.36,-0.28,-0.2 ,-0.12,-0.04, + 0.04, 0.12, 0.2 , 0.28, 0.36, 0.44, 0.52, 0.6 , 0.68, 0.76, 0.84, 0.92}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test6) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {-2.56, 0., 2.56}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test7) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {}, std::vector{0.}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test8) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0. ,-0. ,-0. ,-0. 
,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05, + -0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05,-0.05}); + NDArray dLdwExp('c', {2,3,4}, {-1.296,-1.2 ,-1.104,-1.008,-0.912,-0.816,-0.72 ,-0.624,-0.528,-0.432,-0.336,-0.24 , + -0.144,-0.048, 0.048, 0.144, 0.24 , 0.336, 0.432, 0.528, 0.624, 0.72 , 0.816, 0.912}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test9) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083, + -0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083,-0.02083, -0.02083, -0.02083, -0.02083}); + NDArray dLdwExp('c', {2,3,4}, {0.04, 0.08, 0.12, 0.16, 0.2 , 0.24, 0.28, 0.32,0.36, 0.4 , 0.44, 0.48, + 0.52, 0.56, 0.6 , 0.64,0.68, 0.72, 0.76, 0.8 ,0.84, 0.88, 0.92, 0.96}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + 
ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test10) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,1}, std::vector{12.}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test11) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {2.72, 4., 5.28}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test12) 
{ + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025, + -0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025,-0.025, -0.025, -0.025, -0.025}); + NDArray dLdwExp('c', {2,3,4}, {0.048, 0.096, 0.144, 0.192,0.24 , 0.288, 0.336, 0.384,0.432, 0.48 , 0.528, 0.576, + 0.624, 0.672, 0.72 , 0.768,0.816, 0.864, 0.912, 0.96 ,1.008, 1.056, 1.104, 1.152}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + weights.r(3) = 0.; + + sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, absolute_difference_loss_grad_test13) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., + -0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167,-0.04167, -0.04167, -0.04167, -0.04167}); + NDArray dLdwExp('c', {2,3,1}, {0.8 ,2.08,3.36,4.64,5.92,7.2 }); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + + 
sd::ops::absolute_difference_loss_grad op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdpExp.isSameShape(-*dLdl)); + ASSERT_TRUE(dLdpExp.equalsTo(-*dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, BFloat16_Test_1) { + + NDArray x = NDArrayFactory::create('c', {2,3,4}); + NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(1); + y.linspace(1); + exp.linspace(2,2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res->equalsTo(exp)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, BFloat16_Test_2) { + + NDArray x = NDArrayFactory::create('c', {2,3,4}); + NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(1); + y.linspace(1); + exp.linspace(2,2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res->equalsTo(exp)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, BFloat16_Test_3) { + + NDArray x('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray y('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp('c', {2,3,4}, sd::DataType::BFLOAT16); 
+ + x.linspace(1); + y.linspace(1); + exp.linspace(2,2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res->equalsTo(exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test1) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.25999, -0.755 , -1.25 , -1.745 , -2.24001, -2.73502, -3.23004, -3.72508, -4.22014, -4.71523, -5.21034, -5.70548, + -6.20066, -6.69587, -7.19113, -7.68643, -8.18177, -8.67717, -9.17262, -9.66813,-10.1637 ,-10.65932,-11.15501,-11.65077}); + NDArray dLdwExp('c', {2,3,4}, {0.73395, 0.75335, 0.69315, 0.55335, 0.33395, 0.03495, -0.34366, -0.80186, -1.33967, -1.95708, -2.65411, -3.43074, + -4.28698, -5.22285, -6.23833, -7.33343, -8.50815, -9.76251,-11.0965 ,-12.51013,-14.00341,-15.57633,-17.2289 ,-18.96113}); + NDArray dLdlExp('c', {2,3,4}, {0.04, 0.02,-0. 
,-0.02,-0.04,-0.06,-0.08,-0.1 ,-0.12,-0.14,-0.16,-0.18, + -0.2 ,-0.22,-0.24,-0.26,-0.28,-0.3 ,-0.32,-0.34,-0.36,-0.38,-0.4 ,-0.42}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test2) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,1,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048, + -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577}); + NDArray dLdwExp('c', {2,1,4}, {0.43622, -0.19079, -0.98462, -1.94525,-18.09855,-20.72768,-23.52373,-26.48669}); + NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0. 
, -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126, + -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test3) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.18499,-0.53 ,-0.875 ,-1.22 ,-1.56501,-1.91002,-2.25504,-2.60008,-2.94514,-3.29023,-3.63534,-3.98048, + -4.32566,-4.67087,-5.01613,-5.36143,-5.70677,-6.05217,-6.39762,-6.74313,-7.0887 ,-7.43432,-7.78001,-8.12577}); + NDArray dLdwExp('c', {}, std::vector{-91.52109}); + NDArray dLdlExp('c', {2,3,4}, {0.028, 0.014, -0., -0.014,-0.028, -0.042, -0.056, -0.07 ,-0.084, -0.098, -0.112, -0.126, + -0.14 , -0.154, -0.168, -0.182,-0.196, -0.21 , -0.224, -0.238,-0.252, -0.266, -0.28 , -0.294}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + 
ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test4) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {-12.54779,-28.13393,-50.83936}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test5) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.01542,-0.04417,-0.07292,-0.10167,-0.13042,-0.15917,-0.18792,-0.21667,-0.24543,-0.27419,-0.30294,-0.33171, + -0.36047,-0.38924,-0.41801,-0.44679,-0.47556,-0.50435,-0.53314,-0.56193,-0.59072,-0.61953,-0.64833,-0.67715}); + NDArray dLdwExp('c', {2,3,4}, {0.37794, 0.37906, 0.37554, 0.36739, 0.35461, 0.33719, 0.31514, 0.28846, 0.25714, 0.22119, 0.18061, 0.13539, + 0.08553, 0.03104,-0.02808,-0.09184,-0.16023,-0.23326,-0.31093,-0.39323,-0.48017,-0.57175,-0.66796,-0.76881}); + NDArray dLdlExp('c', {2,3,4}, {0.00233, 0.00117,-0.,-0.00117,-0.00233,-0.0035 ,-0.00467,-0.00583,-0.007 ,-0.00817,-0.00933,-0.0105, + -0.01167,-0.01283,-0.014 ,-0.01517,-0.01633,-0.0175 ,-0.01867,-0.01983,-0.021 ,-0.02217,-0.02333,-0.0245}); + + logits.linspace(-0.08, 0.04); + 
labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test6) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {1.4966 , 0.19776,-1.69436}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test7) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights(sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {}, std::vector{0.}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test8) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, { 0. , 0. , 0. , 0. ,-0.1565 ,-0.191 ,-0.2255 ,-0.26001,-0.29451,-0.32902,-0.36353,-0.39805, + -0.43257,-0.46709,-0.50161,-0.53614,-0.57068,-0.60522,-0.63976,-0.67431,-0.70887,-0.74343,-0.778 ,-0.81258}); + NDArray dLdwExp('c', {2,3,4}, {0.54353, 0.54487, 0.54065, 0.53087, 0.51553, 0.49463, 0.46817, 0.43615, 0.39857, 0.35543, 0.30672, 0.25246, + 0.19264, 0.12725, 0.0563 ,-0.02021,-0.10228,-0.18992,-0.28312,-0.38188,-0.48621,-0.5961 ,-0.71156,-0.83258}); + NDArray dLdlExp('c', {2,3,4}, {-0. ,-0. , 0. , 0. ,-0.0028,-0.0042,-0.0056,-0.007 ,-0.0084,-0.0098,-0.0112,-0.0126, + -0.014 ,-0.0154,-0.0168,-0.0182,-0.0196,-0.021 ,-0.0224,-0.0238,-0.0252,-0.0266,-0.028 ,-0.0294}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test9) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray 
dLdpExp('c', {2,3,4}, {-0.00771, -0.02208, -0.03646, -0.05083,-0.06521, -0.07958, -0.09396, -0.10834,-0.12271, -0.13709, -0.15147, -0.16585, + -0.18024, -0.19462, -0.20901, -0.22339,-0.23778, -0.25217, -0.26657, -0.28096,-0.29536, -0.30976, -0.32417, -0.33857}); + NDArray dLdwExp('c', {2,3,4}, {0.03008, 0.03064, 0.02888, 0.02481, 0.01841, 0.00971, -0.00132, -0.01466,-0.03032, -0.0483 , -0.06859, -0.0912 , + -0.11612, -0.14337, -0.17293, -0.20481,-0.23901, -0.27552, -0.31435, -0.35551,-0.39898, -0.44476, -0.49287, -0.5433 }); + NDArray dLdlExp('c', {2,3,4}, {0.00117, 0.00058, -0. , -0.00058,-0.00117, -0.00175, -0.00233, -0.00292,-0.0035 , -0.00408, -0.00467, -0.00525, + -0.00583, -0.00642, -0.007 , -0.00758,-0.00817, -0.00875, -0.00933, -0.00992,-0.0105 , -0.01108, -0.01167, -0.01225}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test10) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,1}, std::vector{-3.81338}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto 
*dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test11) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdwExp('c', {1,3,1}, {-0.52282,-1.17225,-2.11831}); + + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdw = results.at(1); + + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test12) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. ,-0.07825, -0.0955 , -0.11275, -0.13 ,-0.14726, -0.16451, -0.18177, -0.19902, + -0.21628, -0.23354, -0.25081, -0.26807,-0.28534, -0.30261, -0.31988, -0.33716,-0.35443, -0.37172, -0.389 , -0.40629}); + NDArray dLdwExp('c', {2,3,4}, {0.0361 , 0.03677, 0.03466, 0.02977, 0.0221 , 0.01165, -0.00158, -0.01759,-0.03638, -0.05795, -0.08231, -0.10944, + -0.13935, -0.17204, -0.20752, -0.24577,-0.28681, -0.33063, -0.37723, -0.42661,-0.47877, -0.53372, -0.59144, -0.65196}); + NDArray dLdlExp('c', {2,3,4}, {-0. , -0. , 0. , 0. 
,-0.0014, -0.0021, -0.0028, -0.0035,-0.0042, -0.0049, -0.0056, -0.0063, + -0.007 , -0.0077, -0.0084, -0.0091,-0.0098, -0.0105, -0.0112, -0.0119,-0.0126, -0.0133, -0.014 , -0.0147}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + weights.r(3) = 0.; + + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sigm_cross_entropy_loss_grad_test13) { + + NDArray labels('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , + -0.36047, -0.38924, -0.41801, -0.44679,-0.47556, -0.50435, -0.53314, -0.56193,-0.59072, -0.61953, -0.64833, -0.67715}); + NDArray dLdwExp('c', {2,3,1}, {0.22882, 0.02428,-0.4768 ,-1.27447,-2.36878,-3.75981,}); + NDArray dLdlExp('c', {2,3,4}, {-0. , -0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. 
, 0., + -0.01167, -0.01283, -0.014 , -0.01517,-0.01633, -0.0175 , -0.01867, -0.01983,-0.021 , -0.02217, -0.02333, -0.0245}); + logits.linspace(-0.08, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.r(0) = 0.; + weights.r(1) = 0.; + weights.r(2) = 0.; + + sd::ops::sigm_cross_entropy_loss_grad op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, BFloat16_Test_4) { + + NDArray x = NDArrayFactory::create('c', {2,3,4}); + NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(1); + y.linspace(1); + exp.linspace(2,2); + sd::ops::add op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res->equalsTo(exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, BFloat16_Test_5) { + + NDArray x = NDArrayFactory::create('c', {2,3,4}); + NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(2, 2); + y.linspace(1); + exp.linspace(1); + sd::ops::subtract op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res->equalsTo(exp)); + +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, BFloat16_Test_6) { + + NDArray x = NDArrayFactory::create('c', {2,3,4}); + NDArray y = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + NDArray exp = NDArrayFactory::create('c', {2,3,4});//('c', {2,3,4}, sd::DataType::BFLOAT16); + + x.linspace(2, 2); + y.linspace(1); + exp.linspace(1); + sd::ops::subtract op; + auto results = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto res = results.at(0); + ASSERT_TRUE(res->equalsTo(exp)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test1) { + + NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, sd::DataType::INT32); + NDArray logits('c', {2,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,4}, {0.1176, 0.1224, -0.3726, 0.1326, 0.1176, -0.3776, 0.1274, 0.1326}); + NDArray dLdwExp('c', {2}, {1.36729, 1.40729}); + + logits.linspace(-0.08, 0.04); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss_grad op; + + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test2) { + + NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); + NDArray dLdwExp('c', {1}, std::vector{1.38629}); + + logits = 2.; + 
weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss_grad op; + + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test3) { + + NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {4}, {0.125, 0.125, -0.375, 0.125}); + NDArray dLdwExp('c', {}, std::vector{1.38629}); + + logits = 2.; + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss_grad op; + + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test4) { + + NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {4}, {0.23521, 0.2448 , -0.7452 , 0.26519}); + NDArray dLdwExp('c', {}, std::vector{0.}); + + logits.linspace(-0.08, 0.04); + weights = 0.5; + + sd::ops::softmax_cross_entropy_loss_grad op; + + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, 
results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test5) { + + NDArray labels('c', {4}, {0,0,1,0}, sd::DataType::INT32); + NDArray logits('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {4}, {0.1176, 0.1224, -0.3726, 0.1326}); + NDArray dLdwExp('c', {1}, std::vector{1.36729}); + + logits.linspace(-0.08, 0.04); + weights = 0.5; + + sd::ops::softmax_cross_entropy_loss_grad op; + + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test6) { + + NDArray labels('c', {2,4}, {0,0,1,0, 0,1,0,0}, sd::DataType::INT32); + NDArray logits('c', {2,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,4}, {0.0801, 0.0849, -0.2601, 0.0951, 0.0801, -0.2651, 0.0899, 0.0951}); + NDArray dLdwExp('c', {2}, {-0.014000, 0.014000}); + + logits.linspace(-0.08, 0.04); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss_grad op; + + auto results = op.evaluate({&logits, &weights, &labels}, {0.3}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + 
ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test7) { + + NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}, sd::DataType::INT32); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3}, {0.5, 0., 1.5}); + + NDArray dLdpExp('c', {2,3,4}, {-0.0956 , 0.0306 , 0.03185, 0.03315, 0.,-0., 0., 0., 0.0882 , 0.0918 ,-0.27945, 0.09945, + 0.0294 , 0.0306 , 0.03185,-0.09185,-0., 0., 0., 0., 0.0882 ,-0.2832 , 0.09555, 0.09945}); + NDArray dLdwExp('c', {1,3}, {0.69365, 0.71365, 0.69365}); + + logits.linspace(-0.08, 0.04); + + sd::ops::softmax_cross_entropy_loss_grad op; + + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmax_cross_entropy_loss_grad_test8) { + + NDArray labels('c', {2,3,4,5}, {1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0, + 0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1, + 0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,1,0,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0,1,0}, sd::DataType::INT32); + + NDArray logits('c', {2,3,4,5}, sd::DataType::DOUBLE); + NDArray weights('c', {1,1,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4,5}, {-0.03399, 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, + 0.00866, 0.00901, 
0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, + 0.00799, 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, + 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, + 0.00832, 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, + 0.00768, 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, + 0.00866, 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, + 0.00799, 0.00832,-0.03301, 0.00901, 0.00768, 0.00799, 0.00832, 0.00866,-0.03265,-0.03399, 0.00799, 0.00832, 0.00866, + 0.00901, 0.00768,-0.03367, 0.00832, 0.00866, 0.00901, 0.00768, 0.00799,-0.03335, 0.00866, 0.00901, 0.00768, 0.00799, 0.00832,-0.03301, 0.00901}); + + NDArray dLdwExp('c', {1,1,4}, {0.005, 0.00167, -0.00167, -0.005}); + logits.linspace(-0.08, 0.04); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss_grad op; + + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + // dLdp->printIndexedBuffer(); + + // ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + // ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, SafeDivideMixed_Test1) { + + NDArray labels('c', {2, 3}, {1.0, 2.0, 3.0, -1.0, 2.0, 1.0}); + auto sumDiff = labels.reduceAlongDimension(reduce::Sum, {1}, true); + + NDArray numOfNonZero(sumDiff.shapeInfo(), sd::DataType::INT64, false); + numOfNonZero.assign(1); + sumDiff.applyPairwiseTransform(pairwise::SafeDivide, 
numOfNonZero, sumDiff); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test1) { + + NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,0, 0,0,1,0, 0,0,0,1, 1,0,0,0, 0,1,0,0}); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519, 0.23521, 0.2448,-0.7452, 0.26519, + 0.23521, 0.2448, 0.2548,-0.73481,-0.76479, 0.2448, 0.2548, 0.26519, 0.23521,-0.7552, 0.2548, 0.26519}); + logits.linspace(-0.08, 0.04); + + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&logits, &labels}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test2) { + + NDArray labels('c', {2,3,4}, {1,0,0,0, 0,1,0,1, 0,0,1,0, 0,0,0,1, 1,0,1,0, 0,1,0,0}); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.71836, 0.28164, 0.28164, 0.28164, 0.33051, -0.66949, 0.33051, -0.66949, 0.38785, 0.38785, -0.61215, 0.38785, + 0.28164, 0.28164, 0.28164, -0.71836,-0.66949, 0.33051, -0.66949, 0.33051, 0.38785, -0.61215, 0.38785, 0.38785}); + logits.linspace(-0.08, 0.04); + + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&logits, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test3) { + + NDArray labels('c', {2,3}, {1,0,0, 0,1,1}); + NDArray logits('c', {2,3}, sd::DataType::DOUBLE); + + 
NDArray dLdpExp('c', {2,3}, {-0.52996, 0.47004, 0.47004, 0.52996, -0.47004, -0.47004}); + logits.linspace(-0.08, 0.04); + + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&logits, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test4) { + + NDArray labels('c', {2,1}, {1,1}); + NDArray logits('c', {2,1}, {-0.04, 0.04}); + + NDArray dLdpExp('c', {2,1}, {0., 0.}); + + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&logits, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test5) { + + NDArray labels('c', {2,1}, std::vector{1,0}); + NDArray logits('c', {2,1}, {-0.04, 0.04}); + + NDArray dLdpExp('c', {2,1}, {-0.51999, 0.51999}); + + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&logits, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test6) { + + NDArray labels('c', {1,2}, {1,1.}); + NDArray logits('c', {1,2}, {-0.04, 0.04}); + + NDArray dLdpExp('c', {1,2}, {0, 0.}); + + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&logits, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + 
auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test7) { + + NDArray labels('c', {2}, {0,1}); + NDArray logits('c', {2}, {-0.04, 0.04}); + + NDArray dLdpExp('c', {2}, {0.48001, -0.48001}); + + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&logits, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, softmaxCrossEntropyWithLogits_grad_test8) { + + NDArray labels('c', {1}, std::vector{1}); + NDArray logits('c', {1}, std::vector{0.04}); + + NDArray dLdpExp('c', {1}, std::vector{0}); + + sd::ops::softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&logits, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, Multiply_BP_Test1) { + + NDArray x('c', {3,4,5}, sd::DataType::DOUBLE); + NDArray y('c', {1,1,1}, sd::DataType::DOUBLE); + + NDArray dLdp('c', {3,4,5}, sd::DataType::DOUBLE); + NDArray dLdpExp('c', {3,4,5}, sd::DataType::DOUBLE); + + x.assign(1.0);//linspace(0.1, 0.1); + y.assign(1.0); + dLdp.assign(1.0); + dLdpExp.assign(1.0); + sd::ops::multiply_bp op; + + auto results = op.evaluate({&x, &y, &dLdp}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdo = results.at(0); + ASSERT_TRUE(dLdpExp.isSameShape(dLdo)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdo)); + +} + 
+///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test1) { + + NDArray labels('c', {2}, {2,1}, sd::DataType::INT64); + NDArray logits('c', {2,3}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3}, {0.30061, 0.33222, -0.63283, 0.30061, -0.66778, 0.36717}); + + logits.linspace(0.1, 0.1); + + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&labels, &logits}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test2) { + + NDArray labels('c', {2}, {0,1}, sd::DataType::INT64); + NDArray logits('c', {2,3}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3}, {-0.69939, 0.33222, 0.36717, 0.30061, -0.66778, 0.36717}); + + logits.linspace(-0.1, 0.1); + + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&labels, &logits}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test3) { + + NDArray labels('c', {}, std::vector{1}, sd::DataType::INT64); + NDArray logits('c', {2}, {-0.2, 0.3}); + + NDArray dLdpExp('c', {2}, {0.37754, -0.37754}); + + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&labels, &logits}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + 
+///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test4) { + + NDArray labels('c', {2,3}, {0,1,1, 3,3,2}, sd::DataType::INT64); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {-0.78616, 0.23633, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, 0.21384, -0.76367, 0.26118, 0.28865, + 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, 0.26118, -0.71135, 0.21384, 0.23633, -0.73882, 0.28865}); + logits.linspace(-0.5, 0.1); + + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&labels, &logits}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests11, sparseSoftmaxCrossEntropyWithLogits_grad_test5) { + + NDArray labels('c', {1,1}, std::vector({0}), sd::DataType::INT64); + NDArray logits('c', {1,1,2}, {-0.3,0.2}); + + NDArray dLdpExp('c', {1,1,2}, {-0.62246, 0.62246}); + + sd::ops::sparse_softmax_cross_entropy_loss_with_logits_grad op; + + auto results = op.evaluate({&labels, &logits}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests12.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests12.cpp new file mode 100644 index 000000000..be6a32ee1 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests12.cpp @@ -0,0 +1,3469 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available 
under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// Created by raver on 8/4/2018. +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; + + +class DeclarableOpsTests12 : public testing::Test { +public: + + DeclarableOpsTests12() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests12, test_any_validation_1) { + auto x = NDArrayFactory::create('c', {2, 1}, {1.0, 2.0}); + auto y = NDArrayFactory::create('c', {2}, {1, 0}); + + sd::ops::transpose op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(x.dataType(), z->dataType()); + + +} + + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test1) { + + NDArray labels('c', {2,4}, {0,1,1,0,1,0,1,0}); + NDArray predictions('c', {2,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,4}, {-0. , -0.5, -0.5, -0., -0.5, -0. 
, -0.5, -0.}); + NDArray dLdwExp('c', {2,1}, {1.2, -0.2}); + + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); + + sd::ops::cosine_distance_loss_grad op; + + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, -1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test2) { + + NDArray labels('c', {2,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2}); + NDArray predictions('c', {2,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,4}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,4}, {0.05, -0.15, -1. , 0.7 ,-1.25, 1.5 , -0.6 , -1.1 }); + NDArray dLdwExp('c', {1,4}, {-0.04, 2.86, 0.04, -0.92}); + NDArray dLdlExp('c', {2,4}, {0.2, 0.1, 0. 
, -0.1, -0.2, -0.3, -0.4, -0.5}); + + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); + + sd::ops::cosine_distance_loss_grad op; + + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test3) { + + NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}); + NDArray predictions('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {4}, {0.05, -0.15, -1., 0.7}); + NDArray dLdwExp('c', {1}, std::vector{1.3}); + NDArray dLdlExp('c', {4}, {0.2, 0.1, -0. 
, -0.1}); + + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); + + sd::ops::cosine_distance_loss_grad op; + + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test4) { + + NDArray labels('c', {1,4}, {-0.1, 0.3, 2, -1.4}); + NDArray predictions('c', {1,4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{0.}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {1,4}, {0.05, -0.15, -1., 0.7}); + NDArray dLdwExp('c', {}, std::vector{1.3}); + NDArray dLdlExp('c', {1,4}, {0.2, 0.1, -0. , -0.1}); + + predictions.linspace(-0.4, 0.2); + weights.assign(0.5); + + sd::ops::cosine_distance_loss_grad op; + + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + + +} + + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test5) { + + NDArray labels('c', {4}, {-0.1, 0.3, 2, -1.4}, sd::DataType::DOUBLE); + NDArray predictions('c', {4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {4}, {0.1, -0.3, -2. 
, 1.4}); + NDArray dLdwExp('c', {1,1}, std::vector{0.}); + NDArray dLdlExp('c', {4}, {0.4, 0.2, -0. , -0.2}); + + predictions.linspace(-0.4, 0.2); + weights = 0.5; + + sd::ops::cosine_distance_loss_grad op; + + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test6) { + + NDArray labels('c', {4,1}, {-0.1, 0.3, 2, -1.4}, sd::DataType::DOUBLE); + NDArray predictions('c', {4,1}, sd::DataType::DOUBLE); + NDArray weights('c', {4,1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {4,1}, {0.0125, -0.0375, -0.25 , 0.175}); + NDArray dLdwExp('c', {4,1}, {0.24 , 0.265, 0.25 , 0.32}); + NDArray dLdlExp('c', {4,1}, {0.05 , 0.025, -0. 
, -0.025}); + + predictions.linspace(-0.4, 0.2); + weights = 0.5; + + sd::ops::cosine_distance_loss_grad op; + + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test7) { + + NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {1,3,1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0.00833, -0.025 , -0.16667, 0.11667,-0.20833, 0.25 , -0.1 , -0.18333, 0.00833, -0.025 , -0.16667, 0.28333, + -0.20833, 0.25 , -0.1 , -0.18333, 0.01667, -0.025 , -0.16667, 0.11667,-0.225 , 0.25 , -0.1 , -0.35 }); + NDArray dLdwExp('c', {1,3,1}, {0.50444, 0.89778, -1.40222}); + NDArray dLdlExp('c', {2,3,4}, {0.03333, 0.01667, -0. 
, -0.01667,-0.03333, -0.05 , -0.06667, -0.08333,-0.1, -0.11667, -0.13333, -0.15, + -0.16667, -0.18333, -0.2 , -0.21667,-0.23333, -0.25 , -0.26667, -0.28333,-0.3, -0.31667, -0.33333, -0.35 }); + + predictions.linspace(-0.4, 0.2); + weights = 0.5; + + sd::ops::cosine_distance_loss_grad op; + + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test8) { + + NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,1,1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0.00625, -0.01875, -0.125 , 0.0875,-0.15625, 0.1875 , -0.075 , -0.1375, 0.00625, -0.01875, -0.125 , 0.2125, + -0.15625, 0.1875 , -0.075 , -0.1375, 0.0125 , -0.01875, -0.125 , 0.0875,-0.16875, 0.1875 , -0.075 , -0.2625}); + NDArray dLdwExp('c', {2,1,1}, {0.57, -3.2175}); + NDArray dLdlExp('c', {2,3,4}, {0.025, 0.0125, -0. 
, -0.0125,-0.025, -0.0375, -0.05, -0.0625,-0.075, -0.0875, -0.1 , -0.1125, + -0.125, -0.1375, -0.15, -0.1625,-0.175, -0.1875, -0.2 , -0.2125,-0.225, -0.2375, -0.25, -0.2625}); + + predictions.linspace(-0.4, 0.2); + weights = 0.5; + + sd::ops::cosine_distance_loss_grad op; + + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + + +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cosine_distance_loss_grad_test9) { + + NDArray labels('c', {2,3,4}, {-0.1, 0.3, 2, -1.4, 2.5, -3, 1.2, 2.2,-0.1, 0.3, 2, -3.4, 2.5, -3, 1.2, 2.2,-0.2, 0.3, 2, -1.4, 2.7, -3, 1.2, 4.2}); + NDArray predictions('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {2,3,1}, sd::DataType::DOUBLE); + + NDArray dLdpExp('c', {2,3,4}, {0.05, -0.15, -1. , 0.7,-1.25, 1.5 , -0.6 , -1.1, 0.05, -0.15, -1. , 1.7, + -1.25, 1.5 , -0.6 , -1.1, 0.1 , -0.15, -1. , 0.7,-1.35, 1.5 , -0.6 , -2.1}); + NDArray dLdwExp('c', {2,3,1}, {1.3 , -1.36, 3.62, -6. , -0.98,-19.76}); + NDArray dLdlExp('c', {2,3,4}, {0.2, 0.1, -0. , -0.1,-0.2, -0.3, -0.4, -0.5,-0.6, -0.7, -0.8, -0.9, + -1. , -1.1, -1.2, -1.3,-1.4, -1.5, -1.6, -1.7,-1.8, -1.9, -2. 
, -2.1}); + + predictions.linspace(-0.4, 0.2); + weights = 0.5; + + sd::ops::cosine_distance_loss_grad op; + + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0, 2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *dLdp = results.at(0); + auto *dLdw = results.at(1); + auto *dLdl = results.at(2); + + ASSERT_TRUE(dLdpExp.isSameShape(dLdp)); + ASSERT_TRUE(dLdpExp.equalsTo(dLdp)); + ASSERT_TRUE(dLdwExp.isSameShape(dLdw)); + ASSERT_TRUE(dLdwExp.equalsTo(dLdw)); + ASSERT_TRUE(dLdlExp.isSameShape(dLdl)); + ASSERT_TRUE(dLdlExp.equalsTo(dLdl)); + + +} + + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, hinge_loss_14) { + + NDArray logits('c', {3,4}, sd::DataType::DOUBLE); + NDArray weights('c', {}, std::vector{1.}); + NDArray labels('c', {3,4}, {0,1,1,0,1,0,1,0,1,0,1,0}); + + NDArray output('c', {}, std::vector{0.}, sd::DataType::DOUBLE); + + logits.linspace(1.); + weights.assign(1.); + + sd::ops::hinge_loss op; + Nd4jStatus status = op.execute({&logits, &weights, &labels}, {&output}, {}, {1}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(output.e(0) == 47.); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TestDivideBP_1) { + + NDArray x('c', {3,4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create(2.); + NDArray eps('c', {3,4}, sd::DataType::DOUBLE); + + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2(sd::DataType::DOUBLE); + + x.linspace(2., 2.); + eps.linspace(1.); + + sd::ops::divide_bp op; + Nd4jStatus status = op.execute({&x, &y, &eps}, {&output1, &output2}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + //ASSERT_TRUE(output.e(0) == 47.); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TestDivideBP_2) { + + NDArray x('c', {3,4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create('c', {3,4}); + NDArray eps('c', {3,4}, 
sd::DataType::DOUBLE); + NDArray exp1('c', {3,4}, sd::DataType::DOUBLE); + NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + exp1.assign(1.); + exp2.assign(-2.); + x.linspace(2., 2.); + y.linspace(1.); + eps.linspace(1.); + + sd::ops::divide_bp op; + Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TestReverseDivideBP_1) { + + NDArray x('c', {3,4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create(2.); + NDArray eps('c', {3,4}, sd::DataType::DOUBLE); + + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2(sd::DataType::DOUBLE); + + x.linspace(2., 2.); + eps.linspace(1.); + + sd::ops::reversedivide_bp op; + Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + //ASSERT_TRUE(output.e(0) == 47.); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TestReverseDivideBP_2) { + + NDArray x('c', {3,4}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create('c', {3,4}); + NDArray eps('c', {3,4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3,4}, sd::DataType::DOUBLE); + NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); + + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + + x.linspace(2., 2.); + y.linspace(1.); + eps.linspace(1.); + exp1.assign(1.); + exp2.assign(-2.); + sd::ops::reversedivide_bp op; + Nd4jStatus status = op.execute({&y, &x, &eps}, std::vector{&output2, &output1}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + 
ASSERT_TRUE(output2.equalsTo(exp2)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TestSliceBP_1) { + + NDArray x('c', {3,4}, sd::DataType::DOUBLE); + NDArray eps('c', {2,2}, sd::DataType::DOUBLE); + NDArray exp('c', {3,4}, {0., 0., 0., 0., 0., 1.,1., 0., 0., 1., 1., 0.}); + //NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); + + NDArray output('c', {3, 4}, sd::DataType::DOUBLE); + //NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output.assign(119.113); + x.linspace(1.); + eps.assign(1.); + //exp1.assign(1.); + //exp2.assign(-2.); + sd::ops::slice_bp op; + Nd4jStatus status = op.execute({&x, &eps}, {&output}, {}, {1,1,2,2}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); + //ASSERT_TRUE(output2.equalsTo(exp2)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TestConfusionZero_1) { + + NDArray x('c', {2}, {1,2}, sd::DataType::INT64); + NDArray i('c', {2}, {0,2}, sd::DataType::INT64); + //NDArray eps('c', {2,2}, sd::DataType::DOUBLE); + NDArray exp('c', {4,4}, {0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0}, sd::DataType::INT64); + //NDArray exp2('c', {3,4}, sd::DataType::DOUBLE); + + NDArray output('c', {4, 4}, sd::DataType::INT64); + //NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output.assign(119.113); + x.linspace(1.); + //eps.assign(1.); + //exp1.assign(1.); + //exp2.assign(-2.); + sd::ops::confusion_matrix op; + Nd4jStatus status = op.execute({&x, &i}, {&output}, {}, {4}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); + //ASSERT_TRUE(output2.equalsTo(exp2)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TestMaximumBP_1) { + + NDArray x('c', {3,4}, sd::DataType::DOUBLE); + NDArray y('c', {3,4}, sd::DataType::DOUBLE); + NDArray eps('c', {3,4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3,4}, {0, 0, 0, 0, 0, 0, 7, 
8, 9, 10, 11, 12}, sd::DataType::DOUBLE); + NDArray exp2('c', {3,4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, sd::DataType::DOUBLE); + + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output1.assign(119); + x.linspace(1.); + y.linspace(12., -1.); + eps.linspace(1.); + //exp1.assign(1.); + //exp2.assign(-2.); + sd::ops::maximum_bp op; + Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output1, &output2}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TestMinimumBP_1) { + + NDArray x('c', {3,4}, sd::DataType::DOUBLE); + NDArray y('c', {3,4}, sd::DataType::DOUBLE); + NDArray eps('c', {3,4}, sd::DataType::DOUBLE); + NDArray exp1('c', {3,4}, {0, 0, 0, 0, 0, 0, 7, 8, 9, 10, 11, 12}, sd::DataType::DOUBLE); + NDArray exp2('c', {3,4}, {1, 2, 3, 4, 5, 6, 0, 0, 0, 0, 0, 0}, sd::DataType::DOUBLE); + + NDArray output1('c', {3, 4}, sd::DataType::DOUBLE); + NDArray output2('c', {3, 4}, sd::DataType::DOUBLE); + output1.assign(119); + x.linspace(1.); + y.linspace(12., -1.); + eps.linspace(1.); + //exp1.assign(1.); + //exp2.assign(-2.); + sd::ops::minimum_bp op; + Nd4jStatus status = op.execute({&x, &y, &eps}, std::vector{&output2, &output1}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output1.equalsTo(exp1)); + ASSERT_TRUE(output2.equalsTo(exp2)); +} + + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, reverse_test15) { + + NDArray x('c', {5}, {1,2,3,4,5}, sd::DataType::DOUBLE); + NDArray axis('c', {}, std::vector{0}, sd::DataType::INT32); + NDArray z('c', {5}, sd::DataType::DOUBLE); + NDArray exp('c', {5}, {5,4,3,2,1}, sd::DataType::DOUBLE); + + + sd::ops::reverse op; + // auto result = op.execute({&x, &axis}, {}, {1}, {}); + Nd4jStatus status = op.execute({&x, 
&axis}, {&z}, {}, {1}, {}); + // auto z = result.at(0); + // z->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + // +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, mirrorPad_test17) { + + NDArray x('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); + NDArray padding('c', {2,2}, {1,1,2,2}, sd::DataType::INT64); + NDArray z('c', {4,7}, sd::DataType::DOUBLE); + NDArray exp1('c', {4,7}, {6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1,6, 5, 4, 5, 6, 5, 4,3, 2, 1, 2, 3, 2, 1}, sd::DataType::DOUBLE); + NDArray exp2('c', {4,7}, {2, 1, 1, 2, 3, 3, 2,2, 1, 1, 2, 3, 3, 2,5, 4, 4, 5, 6, 6, 5,5, 4, 4, 5, 6, 6, 5}, sd::DataType::DOUBLE); + + sd::ops::mirror_pad op; + Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect + + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp1.isSameShape(z)); + ASSERT_TRUE(exp1.equalsTo(z)); + + z = 0.; + status = op.execute({&x, &padding}, {&z}, {}, {1}, {}); // symmetric + + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp2.isSameShape(z)); + ASSERT_TRUE(exp2.equalsTo(z)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, mirrorPad_test18) { + + NDArray x('c', {3}, {1,2,3}, sd::DataType::DOUBLE); + NDArray padding('c', {1, 2}, {1,1}, sd::DataType::INT32); + NDArray z('c', {5}, sd::DataType::DOUBLE); + NDArray exp('c', {5}, {2,1,2,3,2}, sd::DataType::DOUBLE); + + sd::ops::mirror_pad op; + Nd4jStatus status = op.execute({&x, &padding}, {&z}, {}, {0}, {}); // reflect + + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, relu_1) { + + NDArray input('c', {1,5,5,6}, { 0.557449, 0.768277, 1.094015, -0.557449, -0.768277, -1.094015,0.563735, 0.900299, 0.789979, -0.563735, -0.900299, -0.789979, + 
0.142528, 0.959611, 0.877506, -0.142528, -0.959611, -0.877506,0.448742, 0.995377, 1.171543, -0.448742, -0.995377, -1.171543, + 0.603772, 0.799391, 0.560310, -0.603772, -0.799391, -0.560310,0.529753, 0.906786, 0.737630, -0.529753, -0.906786, -0.737630, + 0.221464, 0.824996, 0.472221, -0.221464, -0.824996, -0.472221,0.427730, 0.397933, 0.714365, -0.427730, -0.397933, -0.714365, + 0.488365, 1.016589, 0.744197, -0.488365, -1.016589, -0.744197,0.789846, 0.940837, 0.838412, -0.789846, -0.940837, -0.838412, + 0.404485, 0.677328, 0.754997, -0.404485, -0.677328, -0.754997,0.436760, 0.794765, 0.729766, -0.436760, -0.794765, -0.729766, + 0.588081, 0.652226, 0.725522, -0.588081, -0.652226, -0.725522,0.374457, 1.225813, 1.053411, -0.374457, -1.225813, -1.053411, + 0.300958, 0.599417, 0.633234, -0.300958, -0.599417, -0.633234,0.241993, 1.025464, 0.695378, -0.241993, -1.025464, -0.695378, + 0.236289, 0.907919, 1.012100, -0.236289, -0.907919, -1.012100,0.627402, 0.565187, 0.766926, -0.627402, -0.565187, -0.766926, + 0.133276, 0.326284, 0.102804, -0.133276, -0.326284, -0.102804,0.426913, 0.256251, 0.305241, -0.426913, -0.256251, -0.305241, + 0.177977, 0.841799, 0.800615, -0.177977, -0.841799, -0.800615,0.001991, 0.518389, 0.439322, -0.001991, -0.518389, -0.439322, + 0.166846, 0.508224, 0.486687, -0.166846, -0.508224, -0.486687,0.167493, 0.930932, 0.868717, -0.167493, -0.930932, -0.868717, + 0.174864, 0.444607, 0.445000, -0.174864, -0.444607, -0.445000}, sd::DataType::FLOAT32); + + NDArray expected('c', {1,5,5,6}, { 0.557449, 0.768277, 1.094015, 0., 0., 0., 0.563735, 0.900299, 0.789979, 0., 0., 0., + 0.142528, 0.959611, 0.877506, 0., 0., 0., 0.448742, 0.995377, 1.171543, 0., 0., 0., + 0.603772, 0.799391, 0.560310, 0., 0., 0., 0.529753, 0.906786, 0.737630, 0., 0., 0., + 0.221464, 0.824996, 0.472221, 0., 0., 0., 0.427730, 0.397933, 0.714365, 0., 0., 0., + 0.488365, 1.016589, 0.744197, 0., 0., 0., 0.789846, 0.940837, 0.838412, 0., 0., 0., + 0.404485, 0.677328, 0.754997, 0., 0., 0., 
0.436760, 0.794765, 0.729766, 0., 0., 0., + 0.588081, 0.652226, 0.725522, 0., 0., 0., 0.374457, 1.225813, 1.053411, 0., 0., 0., + 0.300958, 0.599417, 0.633234, 0., 0., 0., 0.241993, 1.025464, 0.695378, 0., 0., 0., + 0.236289, 0.907919, 1.012100, 0., 0., 0., 0.627402, 0.565187, 0.766926, 0., 0., 0., + 0.133276, 0.326284, 0.102804, 0., 0., 0., 0.426913, 0.256251, 0.305241, 0., 0., 0., + 0.177977, 0.841799, 0.800615, 0., 0., 0., 0.001991, 0.518389, 0.439322, 0., 0., 0., + 0.166846, 0.508224, 0.486687, 0., 0., 0., 0.167493, 0.930932, 0.868717, 0., 0., 0., + 0.174864, 0.444607, 0.445000, 0., 0., 0.}, sd::DataType::FLOAT32); + + NDArray z('c', {1,5,5,6}, sd::DataType::FLOAT32); + + sd::ops::relu op; + Nd4jStatus status = op.execute({&input}, {&z}, {0}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.isSameShapeStrict(z)); + ASSERT_TRUE(expected.equalsTo(z)); +} + +#include "ops/declarable/helpers/multiUnique.h" +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, multiUnique_1) { + + NDArray input1('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::INT32); + NDArray input2('c', {3,4}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::INT32); + NDArray input3('c', {2,3}, {10,11,12,13,14,15}, sd::DataType::INT32); + NDArray input4('c', {1,5}, {7,8,9,10,11}, sd::DataType::INT32); + NDArray input5('c', {5,3}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::INT32); + + //NDArray indices('c', {1}, {2}, sd::DataType::INT32); + //NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); + + std::vector arrayList({&input1, &input2, &input3, &input4, &input5}); + + ASSERT_FALSE(sd::ops::helpers::multiUnique(arrayList)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, multiUnique_2) { + + NDArray input1('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::INT32); + NDArray input2('c', {3,4}, 
{21,22,23,24,25,26,27,28,29,210,211,212}, sd::DataType::INT32); + NDArray input3('c', {2,3}, {310,311,312,313,314,315}, sd::DataType::INT32); + NDArray input4('c', {1,5}, {47,48,49,410,411}, sd::DataType::INT32); + NDArray input5('c', {5,3}, {51,52,53,54,55,56,57,58,59,510,511,512,513,514,515}, sd::DataType::INT32); + + //NDArray indices('c', {1}, {2}, sd::DataType::INT32); + //NDArray expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); + + std::vector arrayList({&input1, &input2, &input3, &input4, &input5}); + ASSERT_TRUE(sd::ops::helpers::multiUnique(arrayList)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, reduceMeanBp_4) { + + NDArray x('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); + NDArray gradO('c', {5}, sd::DataType::DOUBLE); + NDArray exp('c', {3,5}, sd::DataType::DOUBLE); + + gradO = 1.; + exp = 0.333333; + + sd::ops::reduce_mean_bp op; + auto result = op.evaluate({&x, &gradO}, {}, {0}); + auto output = result.at(0); + + // output->printShapeInfo(); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, reduceMeanBp_5) { + + NDArray x('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}); + NDArray gradO('c', {3}, sd::DataType::DOUBLE); + NDArray exp('c', {3,5}, sd::DataType::DOUBLE); + + gradO = 1.; + exp = 0.2; + + sd::ops::reduce_mean_bp op; + auto result = op.evaluate({&x, &gradO}, {}, {1}); + auto output = result.at(0); + + // output->printShapeInfo(); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, reduceSqnormBp_1) { + + NDArray x('c', {8,6,4}, sd::DataType::DOUBLE); + NDArray gradO('c', {8,6,1}, 
sd::DataType::DOUBLE); + + sd::ops::reduce_sqnorm_bp op; + auto result = op.evaluate({&x, &gradO}, {1}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pullRows_1) { + + NDArray x('c', {5, 1}, {0,1,2,3,4}); + NDArray z('c', {4, 1}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 1}, {0,2,3,4}); + + Nd4jLong indexes[] = {0,2,3,4}; + PointersManager pm(LaunchContext::defaultContext(), "pullRows"); + auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); + + std::vector dims = {1}; + + auto xTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dims); + auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dims); + + Nd4jPointer nativeStart[2]; + +#ifdef __CUDABLAS__ + nativeStart[1] = (x.getContext()->getCudaStream()); +#endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + 4, pidx, + xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), + zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); + + ASSERT_TRUE(z.equalsTo(exp)); + pm.synchronize(); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pullRows_2) { + + NDArray arr('f', {5, 2}, {0,1,2,3,4,5,6,7,8,9}); + NDArray* y = new NDArray(arr.dup('c')); + NDArray x = (*y)({0,0, 0,1}, true); // view, points on first column of y, shape is {5,1} + + NDArray z('c', {4, 1}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 1}, {0,2,3,4}); + + Nd4jLong indexes[] = {0,2,3,4}; + PointersManager pm(LaunchContext::defaultContext(), "pullRows"); + auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); + + std::vector dims = {1}; + + auto xTadPack = 
sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dims); + auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dims); + + Nd4jPointer nativeStart[2]; +#ifdef __CUDABLAS__ + nativeStart[1] = (x.getContext()->getCudaStream()); +#endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + 4, pidx, + xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), + zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); + + ASSERT_TRUE(z.equalsTo(exp)); + pm.synchronize(); + delete y; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, softmax_9) { + NDArray arrC('c', {5,2}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6, -0.7, 0.8, -0.9, 1}, sd::DataType::FLOAT32); + NDArray* arrF = new NDArray(arrC.dup('f')); + + NDArray outCC('c', {5,2}, sd::DataType::FLOAT32); + NDArray outCF('f', {5,2}, sd::DataType::FLOAT32); + NDArray outFC('c', {5,2}, sd::DataType::FLOAT32); + NDArray outFF('c', {5,2}, sd::DataType::FLOAT32); + + sd::ops::softmax op; + auto status1 = op.execute({&arrC}, {&outCC}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status1); + auto status2 = op.execute({&arrC}, {&outCF}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status2); + auto status3 = op.execute({arrF}, {&outFC}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status3); + auto status4 = op.execute({arrF}, {&outFF}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status4); + + // outCC.printIndexedBuffer("\n"); + // outCF.printIndexedBuffer("\n"); + // outFC.printIndexedBuffer("\n"); + // outFF.printIndexedBuffer("\n"); + + ASSERT_EQ(outCC, outCF); + ASSERT_EQ(outCC, outFC); + ASSERT_EQ(outCC, outFF); + + delete arrF; +} + +TEST_F(DeclarableOpsTests12, maxpool_bp_half_1) { + auto x = NDArrayFactory::create('c', {2, 3, 10, 1}, {0.2019043f, 0.6464844f, 0.9116211f, 0.60058594f, 0.34033203f, 
0.7036133f, 0.6772461f, 0.3815918f, 0.87353516f, 0.04650879f, 0.67822266f, 0.8618164f, 0.88378906f, 0.7573242f, 0.66796875f, 0.63427734f, 0.33764648f, 0.46923828f, 0.62939453f, 0.76464844f, -0.8618164f, -0.94873047f, -0.9902344f, -0.88916016f, -0.86572266f, -0.92089844f, -0.90722656f, -0.96533203f, -0.97509766f, -0.4975586f, -0.84814453f, -0.984375f, -0.98828125f, -0.95458984f, -0.9472656f, -0.91064453f, -0.80859375f, -0.83496094f, -0.9140625f, -0.82470703f, 0.4802246f, 0.45361328f, 0.28125f, 0.28320312f, 0.79345703f, 0.44604492f, -0.30273438f, 0.11730957f, 0.56396484f, 0.73583984f, 0.1418457f, -0.44848633f, 0.6923828f, -0.40234375f, 0.40185547f, 0.48632812f, 0.14538574f, 0.4638672f, 0.13000488f, 0.5058594f}); + auto y = NDArrayFactory::create('c', {2, 3, 10, 1}, {0.0f, -0.13391113f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, -0.1751709f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.51904297f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5107422f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}); + auto z = NDArrayFactory::create('c', {2, 3, 10, 1}); + + sd::ops::maxpool2d_bp op; + Context ctx(1); + Nd4jLong iArgs[] = {5,1,1, 2,2,0, 1,1,1, 0,0}; + ctx.setIArguments(iArgs, 11); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_bp_1) { + + NDArray input('c', {2,3,4,10}); + NDArray gradO('c', {2,3,4,10}); + NDArray exp('c', {2,3,4,10}, {1.00438418e-02, 5.25184907e-03, 1.78685773e-03, -1.14537543e-03, 
-4.00071684e-03, -5.31899510e-03, -4.97647980e-03, -4.42161644e-03, -3.95395281e-03, -3.59310722e-03, 2.91823584e-04, -2.18498681e-05, -3.12092161e-04, -6.07360795e-04, -9.36298165e-04, + -1.02553482e-03, -7.91735307e-04, -6.15672267e-04, -4.71792649e-04, -3.42114770e-04, 4.29357824e-05, -5.46473675e-05, -1.48361753e-04, -2.47166492e-04, -3.61090642e-04, -3.81607766e-04, -2.89086485e-04, -2.17203109e-04, -1.56231865e-04, -9.91634734e-05, + 8.99407951e-06, -3.76849275e-05, -8.32021178e-05, -1.31939698e-04, -1.89008832e-04, -1.96661276e-04, -1.47534331e-04, -1.08789405e-04, -7.53896020e-05, -4.36357586e-05, + 1.23124300e-06, -2.60028974e-05, -5.27824741e-05, -8.17063192e-05, -1.15871291e-04, -1.19515295e-04, -8.91248055e-05, -6.49499125e-05, -4.39216528e-05, -2.37579407e-05, -9.34046056e-07, -1.87477999e-05, -3.63574763e-05, -5.54830040e-05, -7.82010393e-05, + -8.02115537e-05, -5.95739621e-05, -4.30659420e-05, -2.86241393e-05, -1.47010251e-05, -1.52835810e-06, -1.40790498e-05, -2.65316012e-05, -4.01083526e-05, -5.62983550e-05, -5.75223821e-05, -4.25982689e-05, -3.06141737e-05, -2.00884024e-05, -9.90276021e-06, + -1.61666367e-06, -1.09328157e-05, -2.02010433e-05, -3.03347279e-05, -4.24536738e-05, -4.32532870e-05, -3.19610226e-05, -2.28673853e-05, -1.48570880e-05, -7.08444895e-06, + -1.53552355e-06, -8.72318924e-06, -1.58886232e-05, -2.37402273e-05, -3.31507035e-05, -3.37014644e-05, -2.48602537e-05, -1.77248403e-05, -1.14254890e-05, -5.30027773e-06, -1.40318230e-06, -7.11624580e-06, -1.28209140e-05, -1.90826468e-05, -2.66006646e-05, + -2.69959855e-05, -1.98865000e-05, -1.41387427e-05, -9.05554589e-06, -4.10473058e-06, -1.26330860e-06, -5.91293519e-06, -1.05618501e-05, -1.56718652e-05, -2.18157675e-05, -2.21090413e-05, -1.62681827e-05, -1.15394150e-05, -7.35144840e-06, -3.26711961e-06, + -1.13179840e-06, -4.98940426e-06, -8.85062400e-06, -1.30997241e-05, -1.82144904e-05, -1.84380206e-05, -1.35542105e-05, -9.59566933e-06, -6.08572736e-06, -2.65887866e-06, + 
-1.01367493e-06, -4.26561428e-06, -7.52358210e-06, -1.11123145e-05, -1.54364170e-05, -1.56106762e-05, -1.14666063e-05, -8.10436813e-06, -5.12021325e-06, -2.20401580e-06, -9.09635219e-07, -3.68808492e-06, -6.47385696e-06, -9.54499774e-06, -1.32485484e-05, + -1.33870126e-05, -9.82651000e-06, -6.93532820e-06, -4.36710525e-06, -1.85539375e-06, -8.18735487e-07, -3.22003825e-06, -5.62928972e-06, -8.28724023e-06, -1.14948289e-05, -1.16066676e-05, -8.51461300e-06, -6.00201292e-06, -3.76846447e-06, -1.58258263e-06, + -7.39498375e-07, -2.83553072e-06, -4.93973403e-06, -7.26259532e-06, -1.00675643e-05, -1.01591886e-05, -7.44886802e-06, -5.24508141e-06, -3.28481428e-06, -1.36524977e-06, + -6.70378654e-07, -2.51585061e-06, -4.36947221e-06, -6.41683391e-06, -8.89049170e-06, -8.96649362e-06, -6.57134478e-06, -4.62275193e-06, -2.88851857e-06, -1.18941352e-06, -6.09944266e-07, -2.24723408e-06, -3.89250545e-06, -5.71062310e-06, -7.90838203e-06, + -7.97212033e-06, -5.84020108e-06, -4.10491293e-06, -2.55976192e-06, -1.04521314e-06, -5.56935277e-07, -2.01937837e-06, -3.48954882e-06, -5.11487451e-06, -7.08044308e-06, -7.13442114e-06, -5.22460778e-06, -3.66942504e-06, -2.28403951e-06, -9.25535005e-07, + -5.10270809e-07, -1.82444705e-06, -3.14605040e-06, -4.60769843e-06, -6.37601988e-06, -6.42213308e-06, -4.70144141e-06, -3.29971408e-06, -2.05053857e-06, -8.25151346e-07, + -4.69036365e-07, -1.65639949e-06, -2.85086708e-06, -4.17237243e-06, -5.77171340e-06, -5.81141694e-06, -4.25308644e-06, -2.98317354e-06, -1.85106614e-06, -7.40148607e-07, -4.32460268e-07, -1.51051631e-06, -2.59534818e-06, -3.79594053e-06, -5.24941379e-06, + -5.28384317e-06, -3.86593183e-06, -2.71007866e-06, -1.67932183e-06, -6.67554332e-07, -3.99893480e-07, -1.38306928e-06, -2.37269478e-06, -3.46823890e-06, -4.79492701e-06, -4.82497671e-06, -3.52932648e-06, -2.47282924e-06, -1.53039912e-06, -6.05077048e-07, + -3.70789934e-07, -1.27108103e-06, -2.17750403e-06, -3.18120783e-06, -4.39700398e-06, -4.42338614e-06, 
-3.23483960e-06, -2.26541715e-06, -1.40042869e-06, -5.50929371e-07}); + input.linspace(1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {5}); + auto gradI = results.at(0); + + ASSERT_EQ(*gradI, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_bp_2) { + + NDArray input('c', {2,3,4,10}); + NDArray gradO('c', {2,3,4,10}); + NDArray exp('c', {2,3,4,10}, {-1.06179598e-03, -2.70050880e-03, -4.02126182e-03, -2.58826977e-03, -2.16024881e-03, -2.20575323e-03, -2.75954953e-03, -4.42477595e-03, -2.89176637e-03, -9.46942251e-04, -1.32603094e-03, -3.34868953e-03, -4.98152524e-03, -3.21313459e-03, -2.68880837e-03, -2.75207381e-03, -3.45109636e-03, -5.54159656e-03, -3.61320702e-03, -1.16457068e-03, + -1.70158676e-03, -4.26037982e-03, -6.33032294e-03, -4.09416296e-03, -3.43742501e-03, -3.52900685e-03, -4.43827361e-03, -7.13911094e-03, -4.64041065e-03, -1.46419462e-03, -2.26016506e-03, -5.59943309e-03, -8.30824208e-03, -5.39253885e-03, -4.54709725e-03, -4.68666852e-03, -5.91615774e-03, -9.53640230e-03, -6.17204653e-03, -1.89000927e-03, + -3.14102764e-03, -7.67878769e-03, -1.13740638e-02, -7.41857197e-03, -6.29213545e-03, -6.51977258e-03, -8.27047508e-03, -1.33656031e-02, -8.59564263e-03, -2.51553906e-03, -4.64272872e-03, -1.11560747e-02, -1.64905936e-02, -1.08321551e-02, -9.26420093e-03, -9.67171416e-03, -1.23506878e-02, -2.00199075e-02, -1.27442302e-02, -3.45497206e-03, + -7.49545777e-03, -1.76018942e-02, -2.59558801e-02, -1.72390267e-02, -1.49321631e-02, -1.57669969e-02, -2.03234926e-02, -3.30405571e-02, -2.06389092e-02, -4.78462130e-03, -1.38390735e-02, -3.14943902e-02, -4.63354364e-02, -3.13667879e-02, -2.77508944e-02, -2.98541505e-02, -3.89749333e-02, -6.32867143e-02, -3.77952419e-02, -5.26650995e-03, + -3.16195861e-02, -6.90807998e-02, -1.01725549e-01, -7.13700354e-02, -6.54785037e-02, -7.25797564e-02, -9.49372798e-02, -1.47399038e-01, -7.21285641e-02, 
2.15010419e-02, -8.06625858e-02, -1.79638922e-01, -2.66877055e-01, -1.64447501e-01, -1.00968637e-01, -2.75682062e-02, 1.13596700e-01, 3.32260162e-01, 5.96845448e-01, 8.13161016e-01, + 9.52381015e-01, 8.13161016e-01, 5.96845508e-01, 3.32260162e-01, 1.13596708e-01, -2.75682174e-02, -1.37202948e-01, -2.71326721e-01, -1.84127048e-01, -7.94974267e-02, 3.29870060e-02, -7.39035010e-02, -1.60488203e-01, -1.04997143e-01, -8.06594491e-02, -7.25797564e-02, -7.87955597e-02, -1.11791104e-01, -7.58660138e-02, -3.48676592e-02, + -4.96974029e-03, -4.04525958e-02, -6.82792515e-02, -4.20900472e-02, -3.21968049e-02, -2.98541524e-02, -3.36477235e-02, -4.95737195e-02, -3.37007530e-02, -1.48636252e-02, -4.92655952e-03, -2.17927732e-02, -3.49853337e-02, -2.15152260e-02, -1.66727621e-02, -1.57669988e-02, -1.81730352e-02, -2.73226351e-02, -1.85334161e-02, -7.91355036e-03, + -3.57114570e-03, -1.33136865e-02, -2.09431648e-02, -1.29161589e-02, -1.01064872e-02, -9.67171136e-03, -1.12970043e-02, -1.71830691e-02, -1.16271935e-02, -4.84848116e-03, -2.59314431e-03, -8.91274121e-03, -1.38697922e-02, -8.58002994e-03, -6.75992295e-03, -6.51977304e-03, -7.68158771e-03, -1.17703741e-02, -7.94785097e-03, -3.25604435e-03, + -1.94202550e-03, -6.36530807e-03, -9.84015409e-03, -6.10316684e-03, -4.83274320e-03, -4.68666898e-03, -5.55526093e-03, -8.55536573e-03, -5.76688722e-03, -2.33053416e-03, -1.50016253e-03, -4.76644421e-03, -7.33569637e-03, -4.55961144e-03, -3.62428720e-03, -3.52900638e-03, -4.20164689e-03, -6.49448857e-03, -4.37143166e-03, -1.74761284e-03, + -1.19028054e-03, -3.69978836e-03, -5.67591935e-03, -3.53418733e-03, -2.81759514e-03, -2.75207404e-03, -3.28776496e-03, -5.09600528e-03, -3.42601724e-03, -1.35771628e-03, -9.65878542e-04, -2.95373448e-03, -4.52052988e-03, -2.81889434e-03, -2.25270819e-03, -2.20575323e-03, -2.64216494e-03, -4.10421193e-03, -2.75646802e-03, -1.08450721e-03, + -7.98697409e-04, -2.41194153e-03, -3.68447183e-03, -2.30037421e-03, -1.84193184e-03, -1.80714857e-03, 
-2.16938392e-03, -3.37567786e-03, -2.26523401e-03, -8.85842834e-04, -6.71049987e-04, -2.00629188e-03, -3.06024216e-03, -1.91263494e-03, -1.53396139e-03, -1.50748459e-03, -1.81288645e-03, -2.82496959e-03, -1.89429161e-03, -7.36965681e-04, + -5.71501616e-04, -1.69480499e-03, -2.58198148e-03, -1.61517004e-03, -1.29717519e-03, -1.27655920e-03, -1.53747783e-03, -2.39865575e-03, -1.60740130e-03, -6.22576685e-04, -4.92433901e-04, -1.45049067e-03, -2.20754091e-03, -1.38200901e-03, -1.11122860e-03, -1.09486456e-03, -1.32032647e-03, -2.06194492e-03, -1.38099224e-03, -5.32818493e-04}); + + input.linspace(-10, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {2}); + auto gradI = results.at(0); + + ASSERT_EQ(*gradI, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_bp_3) { + + NDArray input('c', {2,3,4,10}); + NDArray gradO('c', {2,3,4,10}); + NDArray exp('c', {2,3,4,10}, {-6.78180193e-04, -1.06947345e-03, -1.50362519e-03, -1.47711602e-03, -1.45060697e-03, -1.42409769e-03, -1.39758852e-03, -1.37107936e-03, -8.79839936e-04, -4.27795108e-04, -8.62496032e-04, -1.34585891e-03, -1.88281795e-03, -1.84591592e-03, -1.80901436e-03, -1.77211256e-03, -1.73521065e-03, -1.69830909e-03, -1.08184782e-03, -5.13895764e-04, + -1.13227055e-03, -1.74428569e-03, -2.42520543e-03, -2.37169350e-03, -2.31818156e-03, -2.26466986e-03, -2.21115816e-03, -2.15764646e-03, -1.36136822e-03, -6.26647263e-04, -1.54878304e-03, -2.34815548e-03, -3.23930010e-03, -3.15753091e-03, -3.07576265e-03, -2.99399323e-03, -2.91222427e-03, -2.83045508e-03, -1.76287338e-03, -7.75904860e-04, + -2.23870482e-03, -3.32566188e-03, -4.54067392e-03, -4.40674182e-03, -4.27281018e-03, -4.13887901e-03, -4.00494691e-03, -3.87101574e-03, -2.36659218e-03, -9.72117065e-04, -3.49745504e-03, -5.05724549e-03, -6.80746930e-03, -6.56589260e-03, -6.32431870e-03, -6.08274434e-03, -5.84116904e-03, -5.59959421e-03, 
-3.32604628e-03, -1.21081201e-03, + -6.14068285e-03, -8.55270587e-03, -1.12749329e-02, -1.07723922e-02, -1.02698486e-02, -9.76730697e-03, -9.26476624e-03, -8.76222178e-03, -4.94601438e-03, -1.37539487e-03, -1.30690653e-02, -1.72132626e-02, -2.19351258e-02, -2.06174850e-02, -1.92998387e-02, -1.79821979e-02, -1.66645572e-02, -1.53469117e-02, -7.72346184e-03, -5.22134826e-04, + -3.99478227e-02, -4.78655733e-02, -5.70126995e-02, -5.16961850e-02, -4.63796593e-02, -4.10631336e-02, -3.57466117e-02, -3.04300785e-02, -9.11374856e-03, 1.14024431e-02, -2.35893592e-01, -2.17480078e-01, -1.88097835e-01, -1.38812393e-01, -8.95269737e-02, -4.02415469e-02, 9.04385652e-03, 5.83292767e-02, 1.78530529e-01, 2.96026409e-01, + 4.16666657e-01, 2.79557735e-01, 1.36546940e-01, 7.49502778e-02, 1.33536234e-02, -4.82430384e-02, -1.09839723e-01, -1.71436355e-01, -2.33033031e-01, -2.74476141e-01, 1.54189002e-02, -8.10869783e-03, -3.24862264e-02, -3.88403721e-02, -4.51945364e-02, -5.15486896e-02, -5.79028539e-02, -6.42570183e-02, -5.45457527e-02, -4.61437553e-02, + -2.29711179e-04, -8.06892477e-03, -1.63567103e-02, -1.78351123e-02, -1.93135180e-02, -2.07919199e-02, -2.22703181e-02, -2.37487257e-02, -1.87229179e-02, -1.43175106e-02, -1.37000845e-03, -5.16320160e-03, -9.21433326e-03, -9.76086594e-03, -1.03073996e-02, -1.08539313e-02, -1.14004640e-02, -1.19469995e-02, -9.08647850e-03, -6.55380823e-03, + -1.23490533e-03, -3.45137389e-03, -5.83263952e-03, -6.09064987e-03, -6.34865928e-03, -6.60666777e-03, -6.86467718e-03, -7.12268520e-03, -5.30054048e-03, -3.67741752e-03, -9.94500006e-04, -2.44303374e-03, -4.00528917e-03, -4.14666394e-03, -4.28803731e-03, -4.42941114e-03, -4.57078544e-03, -4.71215881e-03, -3.45545518e-03, -2.33156094e-03, + -7.93270417e-04, -1.81236281e-03, -2.91444198e-03, -3.00004939e-03, -3.08565609e-03, -3.17126350e-03, -3.25687067e-03, -3.34247784e-03, -2.42513884e-03, -1.60246110e-03, -6.39747130e-04, -1.39506557e-03, -2.21352675e-03, -2.26921216e-03, -2.32489733e-03, 
-2.38058274e-03, -2.43626791e-03, -2.49195332e-03, -1.79354590e-03, -1.16592250e-03, + -5.23828785e-04, -1.10576022e-03, -1.73730974e-03, -1.77553250e-03, -1.81375467e-03, -1.85197743e-03, -1.89020019e-03, -1.92842260e-03, -1.37922564e-03, -8.84913374e-04, -4.35433642e-04, -8.97393096e-04, -1.39935245e-03, -1.42670958e-03, -1.45406683e-03, -1.48142409e-03, -1.50878134e-03, -1.53613824e-03, -1.09309505e-03, -6.93831593e-04, + -3.66991735e-04, -7.42538832e-04, -1.15100679e-03, -1.17125409e-03, -1.19150116e-03, -1.21174823e-03, -1.23199564e-03, -1.25224248e-03, -8.87364266e-04, -5.58210537e-04, -3.13144788e-04, -6.24410110e-04, -9.63238359e-04, -9.78639582e-04, -9.94040747e-04, -1.00944215e-03, -1.02484343e-03, -1.04024459e-03, -7.34565372e-04, -4.58585098e-04, + -2.70129647e-04, -5.32291830e-04, -8.17865424e-04, -8.29851197e-04, -8.41836852e-04, -8.53822567e-04, -8.65808397e-04, -8.77794111e-04, -6.18013146e-04, -3.83307983e-04, -2.35282409e-04, -4.59096394e-04, -7.03040219e-04, -7.12549896e-04, -7.22059398e-04, -7.31569016e-04, -7.41078693e-04, -7.50588137e-04, -5.27105702e-04, -3.25074652e-04}); + + input.linspace(-10, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {7}); + auto gradI = results.at(0); + + ASSERT_EQ(*gradI, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_bp_4) { + + NDArray input('c', {2,3,4,10}); + NDArray gradO('c', {2,3,4,10}); + NDArray exp('c', {2,3,4,10}, {-0.00119282, -0.00116995, -0.00114708, -0.00112421, -0.00110134, -0.00107847, -0.00105559, -0.00103272, -0.00100985, -0.00098698, -0.00150102, -0.00146918, -0.00143734, -0.0014055 , -0.00137366, -0.00134182, -0.00130998, -0.00127814, -0.0012463 , -0.00121446, + -0.00194534,-0.00189916, -0.00185299, -0.00180681, -0.00176064, -0.00171446, -0.00166829, -0.00162211, -0.00157593, -0.00152976, -0.0026189 , -0.00254833, -0.00247776, -0.00240719, -0.00233662, -0.00226605, 
-0.00219548, -0.00212491, -0.00205434, -0.00198377, + -0.00370962, -0.00359401, -0.00347839, -0.00336277, -0.00324716, -0.00313154, -0.00301593, -0.00290031, -0.00278469, -0.00266908, -0.00564327, -0.00543464, -0.00522602, -0.00501739, -0.00480876, -0.00460013, -0.0043915 , -0.00418288, -0.00397425, -0.00376562, + -0.00955302, -0.00911865, -0.00868428, -0.00824992, -0.00781555, -0.00738118, -0.00694682, -0.00651245, -0.00607808, -0.00564371, -0.01927758, -0.01813637, -0.01699515, -0.01585394, -0.01471272, -0.01357151, -0.01243029, -0.01128908, -0.01014786, -0.00900664, + -0.05409876, -0.04945958, -0.04482041, -0.04018124, -0.03554206, -0.03090289, -0.02626371, -0.02162454, -0.01698537, -0.01234619, -0.26145172, -0.214688 , -0.16792431, -0.12116055, -0.07439683, -0.02763309, 0.01913062, 0.06589434, 0.11265809, 0.15942183, + 0.25974026, 0.19902176, 0.13830325, 0.07758474, 0.01686624, -0.04385226, -0.10457078, -0.16528927, -0.22600779, -0.2867263 , -0.01177884, -0.0173331 , -0.02288735, -0.02844159, -0.03399584, -0.0395501 , -0.04510435, -0.05065861, -0.05621284, -0.0617671 , + -0.00944993, -0.01073084, -0.01201174, -0.01329265, -0.01457355, -0.01585446, -0.01713536, -0.01841627, -0.01969717, -0.02097807, -0.00589878, -0.00637122, -0.00684368, -0.00731612, -0.00778858, -0.00826102, -0.00873347, -0.00920592, -0.00967837, -0.01015082, + -0.00390961, -0.00413245, -0.00435528, -0.00457812, -0.00480095, -0.00502378, -0.00524662, -0.00546945, -0.00569229, -0.00591512, -0.00275609, -0.00287813, -0.00300018, -0.00312222, -0.00324427, -0.00336631, -0.00348836, -0.0036104 , -0.00373245, -0.00385449, + -0.00203982, -0.00211371, -0.00218759, -0.00226147, -0.00233536, -0.00240924, -0.00248312, -0.00255701, -0.00263089, -0.00270478, -0.00156781, -0.00161586, -0.00166391, -0.00171197, -0.00176002, -0.00180807, -0.00185612, -0.00190417, -0.00195223, -0.00200028, + -0.00124141, -0.00127439, -0.00130737, -0.00134035, -0.00137333, -0.00140631, -0.00143929, -0.00147227, -0.00150525, 
-0.00153822, -0.00100674, -0.00103034, -0.00105394, -0.00107754, -0.00110115, -0.00112475, -0.00114835, -0.00117195, -0.00119556, -0.00121916, + -0.00083255, -0.00085002, -0.00086748, -0.00088495, -0.00090242, -0.00091989, -0.00093735, -0.00095482, -0.00097229, -0.00098976, -0.0006998 , -0.00071308, -0.00072637, -0.00073965, -0.00075294, -0.00076623, -0.00077951, -0.0007928 , -0.00080609, -0.00081937, + -0.00059635, -0.00060669, -0.00061703, -0.00062737, -0.00063771, -0.00064805, -0.00065839, -0.00066873, -0.00067906, -0.0006894 , -0.0005142 , -0.0005224 , -0.00053061, -0.00053881, -0.00054701, -0.00055522, -0.00056342, -0.00057162, -0.00057983, -0.00058803}); + + input.linspace(-10, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 1}, {12}); + auto gradI = results.at(0); + + ASSERT_EQ(*gradI, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_bp_5) { + + NDArray input('c', {2,2,2,5}); + NDArray gradO('c', {2,2,2,5}); + NDArray exp('c', {2,2,2,5}, {6.2497472e-03, -3.4008762e-03, -1.5232352e-02, 2.3018382e-04, 1.3257053e-02, 7.1492628e-03, -5.4330104e-03, -2.0878183e-02, 1.5153568e-03, 2.0571884e-02, + 6.7926152e-03, -1.0990440e-02, -3.2685306e-02, 7.2436016e-03, 4.2120241e-02, -1.3439789e-02, -3.4284033e-02, -4.4852167e-02, 8.8073254e-02, 2.2223940e-01, + 4.0824831e-01, 2.1201703e-01, 3.8555145e-02, -3.1969927e-02, -3.0673094e-02, 5.2034661e-02, 1.0463811e-02, -3.6619946e-02, -1.3280880e-02, 5.9767403e-03, + 2.3028374e-02, 2.0452859e-03, -2.2533152e-02, -6.1039329e-03, 7.2805062e-03, 1.4290780e-02, 3.8017845e-04, -1.6107092e-02,-3.6896234e-03, 6.4357026e-03}); + input.linspace(-20, 1); + // gradO.linspace(0.1, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 1., 0.5}, {2}); + auto gradI = results.at(0); + + ASSERT_EQ(*gradI, exp); + + +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_bp_6) { + + NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); + NDArray gradO('c', {1,1,1,5}); + NDArray exp('c', {1,1,1,5}, {0.06926288, 0.04360996, 0.01795704, -0.00769587, -0.0333488}); + // gradO.linspace(-1.5, 0.1); + gradO = 1; + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {10}); + auto gradI = results.at(0); + + ASSERT_EQ(*gradI, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_bp_7) { + + NDArray input('c', {2,2,2,5}); + NDArray gradO('c', {2,2,2,5}); + + input.linspace(-20, 1); + gradO.linspace(-1.5, 0.1); + + const OpArgsHolder argsHolderFF({&input}, {1,2,0.5}, {2}); + const OpArgsHolder argsHolderBP({&input, &gradO}, {1,2,0.5}, {2}); + + sd::ops::lrn opFF; + sd::ops::lrn_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_bp_8) { + + NDArray input('c', {1,1,1,5}, {1, 2, 3, 4, 5}); + NDArray gradO('c', {1,1,1,5}, {2, 3, 4, 5, 6}); + + const OpArgsHolder argsHolderFF({&input}, {1,2,0.5}, {2}); + const OpArgsHolder argsHolderBP({&input, &gradO}, {1,2,0.5}, {2}); + + sd::ops::lrn opFF; + sd::ops::lrn_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_bp_9) { + + NDArray input('c', {1,1,1,5}, {1,2,3,4,5}); + NDArray gradO('c', {1,1,1,5}, {1, 1, 1, 1, 1}); + NDArray exp('c', {1,1,1,5}, {0.1084472 , 0.03816165, 0.00978456, -0.01859251,-0.02511311}); + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {3}); + auto gradI = 
results.at(0); + + // for (int i = 0; i < exp.lengthOf(); ++i) + // printf("%10.5f %10.5f\n", exp.e(i), gradI->e(i)); + + ASSERT_EQ(*gradI, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_bp_10) { + + NDArray input('c', {1,1,1,1}, std::vector{1}); + NDArray gradO('c', {1,1,1,1}, std::vector{1}); + NDArray exp('c', {1,1,1,1}, std::vector{0.19245008}); + + sd::ops::lrn_bp op; + + auto results = op.evaluate({&input, &gradO}, {1., 2., 0.5}, {1}); + auto gradI = results.at(0); + + ASSERT_EQ(*gradI, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_1) { + + NDArray input('c', {2,2,2,5}); + NDArray exp('c', {2,2,2,5}, {-0.42923987, -0.3623817 , -0.3152079 , -0.34268343, -0.3836809, -0.43648192, -0.3652726 , -0.31428117, -0.3379276 , -0.3731494 , + -0.45129365, -0.37083852, -0.3111639 , -0.3260225 , -0.34698898, -0.4975186 , -0.3831305 , -0.2847474 , -0.25607377, -0.18569534, + 0., 0.18569534, 0.25607377, 0.38411066, 0.52075565,0.33633637, 0.32117262, 0.30966178, 0.37259716, 0.45631808, + 0.36986336, 0.33643705, 0.31394684, 0.36608824, 0.43857202, 0.3821113 , 0.34197718, 0.31508508, 0.36284128, 0.4303756 }); + + input.linspace(-20, 1); + + sd::ops::lrn op; + + auto results = op.evaluate({&input}, {1., 2., 0.5}, {2}); + auto output = results.at(0); + + ASSERT_EQ(*output, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_2) { + + NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); + NDArray exp('c', {1,1,1,5}, {0.09530295, 0.1906059 , 0.28590885, 0.3812118 , 0.47651473}); + + sd::ops::lrn op; + + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); + auto output = results.at(0); + ASSERT_EQ(*output, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_3) { + + NDArray input('c', {1,1,1,1}, 
std::vector{1.}); + NDArray exp('c', {1,1,1,1}, std::vector{0.69006556}); + + sd::ops::lrn op; + + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {5}); + auto output = results.at(0); + ASSERT_EQ(*output, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_4) { + + NDArray input('c', {1,1,1,1}, std::vector{1.}); + NDArray exp('c', {1,1,1,1}, std::vector{0.69006556}); + + sd::ops::lrn op; + + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); + auto output = results.at(0); + ASSERT_EQ(*output, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, lrn_5) { + + NDArray input('c', {1,1,1,5}, {1, 2., 3, 4, 5}); + NDArray exp('c', {1,1,1,5}, {0.69006556, 0.70272833, 0.7051508 , 0.7060045 , 0.7064008}); + + sd::ops::lrn op; + + auto results = op.evaluate({&input}, {0.1, 2., 0.5}, {0}); + auto output = results.at(0); + ASSERT_EQ(*output, exp); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, inTopK_1) { + + NDArray x('c', {4, 5}, {11.0, 14.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 5.0, 16.0, 9.0, 13.5, 7.0}); + NDArray y('c', {4}, {0., 0, 0, 0}, sd::DataType::INT64); + NDArray z('c', {4}, {1., 1, 1, 1}, sd::DataType::BOOL); + + NDArray expV('c', {4}, {1., 0, 0, 0}, sd::DataType::BOOL); + + sd::ops::in_top_k op; + Nd4jStatus status = op.execute({&x, &y, }, {&z}, {}, {2}, {}); + + // z.printIndexedBuffer(); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(expV.isSameShape(z)); + ASSERT_TRUE(expV.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, inTopK_2) { + + auto input = NDArrayFactory::create('c', {4, 5}); + auto idx = NDArrayFactory::create('c', {4}); + + auto exp = NDArrayFactory::create({false, false, false, true}); + + int exclusive, reverse; + 
input.linspace(1); + idx.linspace(1); + + sd::ops::in_top_k op; + + auto res = op.evaluate({&input, &idx}, {}, {1}); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + //res.at(0)->printIndexedBuffer("IN_TOP_K output"); + ASSERT_TRUE(res.at(0)->equalsTo(&exp)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, inTopK_3) { + auto x = NDArrayFactory::create('c', {2, 3}, {1.0, 11.0, 3.0, 14.0, 5.0, 6.0}); + auto y = NDArrayFactory::create('c', {2}, {1, 1}); + auto expV = NDArrayFactory::create('c', {2}, {true, false}); + + sd::ops::in_top_k op; + auto result = op.evaluate({&x, &y}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); + + auto v = result.at(0); + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, inTopK_4) { + auto x = NDArrayFactory::create('c', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} ); + auto y = NDArrayFactory::create('c', {6}, {0, 0, 0, 0, 0, 0}); + auto expV = NDArrayFactory::create('c', {6}, {true, false, true, false, false, true}); + + sd::ops::in_top_k op; + auto result = op.evaluate({&x, &y}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); + + auto v = result.at(0); + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, inTopK_5) { + auto x = NDArrayFactory::create('f', {6, 4}, {11.0, 3.0, 14.0, 5.0, 6.0, 9.0, 3.5, 7.0, 21.0, 3.0, 14.0, 15.0, 6.0, 9.0, 3.5, 7.0, 11.0, 13.0, 14.0, 5.0, 16.0, 9.0, 13.5, 7.0} ); + auto y = NDArrayFactory::create('f', {6}, {0, 0, 0, 0, 0, 0}); + auto expV = NDArrayFactory::create('f', {6}, {true, false, false, false, false, false }); 
+ + sd::ops::in_top_k op; + auto result = op.evaluate({&x, &y}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); + + auto v = result.at(0); + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cube_1) { + + NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); + NDArray exp('c', {2, 3}, {1., 8., 27., 64., 125, 216}); + + sd::ops::cube op; + + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, cube_bp_1) { + + NDArray x('c', {2, 3}, {1., 2., 3., 4., 5, 6}); + NDArray gradO('c', {2, 3}, sd::DataType::DOUBLE); + NDArray exp('c', {2, 3}, {1.5, 6., 13.5, 24., 37.5, 54}); + + gradO = 0.5; + + sd::ops::cube_bp op; + + auto result = op.evaluate({&x, &gradO}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////// +// CONSTANT mode 2D +TEST_F(DeclarableOpsTests12, pad_tests1) { + + + NDArray input('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::FLOAT32); + NDArray paddings('c', {2,2}, {1,1,2,2}, sd::DataType::INT32); + NDArray expected('c', {4,7}, {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}, sd::DataType::FLOAT32); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + 
+//////////////////////////////////////////////////////////////////// +// REFLECT mode 2D +TEST_F(DeclarableOpsTests12, pad_tests2) { + + float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; + int padBuff[] = {1,1,2,2}; + float expBuff[] = {6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f, 6.f, 5.f, 4.f, 5.f, 6.f, 5.f, 4.f, 3.f, 2.f, 1.f, 2.f, 3.f, 2.f, 1.f}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + +//////////////////////////////////////////////////////////////////// +// SYMMETRIC mode 2D +TEST_F(DeclarableOpsTests12, pad_tests3) { + + float inBuff[] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}; + Nd4jLong padBuff[] = {1,1,2,2}; + float expBuff[] = {2.f, 1.f, 1.f, 2.f, 3.f, 3.f, 2.f, 2.f,1.f,1.f,2.f,3.f,3.f,2.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f, 5.f,4.f,4.f,5.f,6.f,6.f,5.f}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + +//////////////////////////////////////////////////////////////////// +// CONSTANT mode 3D +TEST_F(DeclarableOpsTests12, pad_tests4) { + + float inBuff[] = {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f}; + int padBuff[] = 
{1,1,2,2,2,2}; + float expBuff[] = {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 2.f, 3.f, 0.f, 0.f, 0.f, 0.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, + 7.f, 8.f, 9.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 11.f, 12.f, 0.f, + 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 16.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + // for(int i = 0; i < expected.lengthOf(); ++i) { + // float one = expected.e(i); + // float two = result->e(i); + // if(one != two) + // printf("%i : %f, %f\n", i, one, two); + // } + + +} + + + +//////////////////////////////////////////////////////////////////// +// REFLECT mode 3D +TEST_F(DeclarableOpsTests12, pad_tests5) { + + double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + int padBuff[] = {1,1,2,2,2,2}; + double expBuff[] = {18,17,16,17,18,17,16, 
15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + +//////////////////////////////////////////////////////////////////// +// SYMMETRIC mode 3D +TEST_F(DeclarableOpsTests12, pad_tests6) { + + double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + int padBuff[] = {1,1,2,2,2,2}; + double expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; + + auto input = NDArrayFactory::create(inBuff, 'c', 
{2,3,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +// CONSTANT mode 4D +TEST_F(DeclarableOpsTests12, pad_tests7) +{ + + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} 
+ +//////////////////////////////////////////////////////////////////// +// REFLECT mode 4D +TEST_F(DeclarableOpsTests12, pad_tests8) +{ + + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +////////////////////////////////////////////////////////////////// +// SYMMETRIC mode 4D +TEST_F(DeclarableOpsTests12, pad_tests9) +{ + + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 
1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests10) { + + auto input = NDArrayFactory::create('c', {2,3,4}); + auto paddings = NDArrayFactory::create('c', {3,2}, {0,0, 0,1, 0,0}); + auto expected = NDArrayFactory::create('c', {2,4,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,0.,0.,0.,0.}); + + input = 1.f; + //input.assign(1.); + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + 
ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests11) { + + auto input = NDArrayFactory::create('c', {2,3,4}); + auto paddings = NDArrayFactory::create('c', {3,2}, {0,0, 0,1, 0,0}); + auto expected = NDArrayFactory::create('c', {2,4,4}, {1., 2., 3., 4., 5., 6., 7., 8., 9.,10.,11.,12., 5., 6., 7., 8.,13.,14.,15.,16.,17.,18.,19.,20.,21.,22.,23.,24.,17.,18.,19.,20.}); + + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests12) { + + auto input = NDArrayFactory::create('c', {2,3,4,5}); + auto paddings = NDArrayFactory::create('c', {4,2}, {0,0, 0,1, 0,1, 0,0}); + auto expected = NDArrayFactory::create('c', {2,4,5,5}, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 16., 17., 18., 19., 20., + 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 36., 37., 38., 39., 40., + 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., 57., 58., 59., 60., + 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56., 57., 58., 59., 60., 56., 57., 58., 59., 60., + 61., 62., 63., 64., 65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 76., 77., 78., 79., 80., + 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 96., 97., 98., 99.,100., 96., 97., 98., 99.,100., + 101.,102.,103.,104.,105.,106.,107.,108.,109.,110.,111.,112.,113.,114.,115.,116.,117.,118.,119.,120.,116.,117.,118.,119.,120., + 
101.,102.,103.,104.,105.,106.,107.,108.,109.,110.,111.,112.,113.,114.,115.,116.,117.,118.,119.,120.,116.,117.,118.,119.,120.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests13) { + + auto input = NDArrayFactory::create('c', {5}); + auto paddings = NDArrayFactory::create('c', {1,2}, {2,3}); + auto expected = NDArrayFactory::create('c', {10}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests14) { + + auto input = NDArrayFactory::create('c', {1,5}); + auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,2,3}); + auto expected = NDArrayFactory::create('c', {1,10}, {2., 1., 1., 2., 3., 4., 5., 5., 4., 3.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests15) { + + auto input = NDArrayFactory::create('c', {1,5}); + auto paddings = NDArrayFactory::create('c', {2,2}, {1,1,0,0}); + auto expected = 
NDArrayFactory::create('c', {3,5}, {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 1., 2., 3., 4., 5.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests16) { + + auto input = NDArrayFactory::create('c', {5,1}); + auto paddings = NDArrayFactory::create('c', {2,2}, {2,3,0,0}); + auto expected = NDArrayFactory::create('c', {10,1}, {3., 2., 1., 2., 3., 4., 5., 4., 3., 2.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests17) { + + auto input = NDArrayFactory::create('c', {5,1}); + auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,1,0}); + auto expected = NDArrayFactory::create('c', {5,2}, {1.,1., 2.,2., 3.,3., 4.,4., 5.,5.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests18) { + + auto input = NDArrayFactory::create('c', {5}); + auto paddings = NDArrayFactory::create('c', {1,2}, {0,0}); + auto expected = NDArrayFactory::create('c', {5}, {1.,2.,3.,4.,5.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = 
op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests19) { + + auto input = NDArrayFactory::create('c', {5,1}); + auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,0,0}); + auto expected = NDArrayFactory::create('c', {5,1}, {1., 2., 3., 4., 5.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests20) { + + auto input = NDArrayFactory::create('c', {1,5}); + auto paddings = NDArrayFactory::create('c', {2,2}, {0,0,0,0}); + auto expected = NDArrayFactory::create('c', {1,5}, {1., 2., 3., 4., 5.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests21) { + + auto input = NDArrayFactory::create('c', {1,3,1,5}); + auto paddings = NDArrayFactory::create('c', {4,2}, {0,0, 0,1, 0,1, 0,0}); + auto expected = NDArrayFactory::create('c', {1,4,2,5}, {1., 2., 3., 4., 5., 1., 2., 3., 4., 5., 6., 7., 8., 9.,10., 6., 7., 8., 9.,10., + 11.,12.,13.,14.,15.,11.,12.,13.,14.,15.,11.,12.,13.,14.,15.,11.,12.,13.,14.,15.}); + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, 
{2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests22) { + + auto input = NDArrayFactory::create('c', {1,1}); + auto paddings = NDArrayFactory::create('c', {2,2}, {0,0, 0,0}); + auto expected = NDArrayFactory::create('c', {1,1}, {1.}); + + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests23) { + + auto input = NDArrayFactory::create('c', {1,1}); + auto paddings = NDArrayFactory::create('c', {2,2}, {0,0, 1,0}); + auto expected = NDArrayFactory::create('c', {1,2}, {0.,1.}); + + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printShapeInfo("r"); + // expected.printShapeInfo("e"); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests24) { + + auto input = NDArrayFactory::create('c', {1}); + auto paddings = NDArrayFactory::create('c', {1,2}, {0,0}); + auto expected = NDArrayFactory::create('c', {1}, {1.}); + + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + 
ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests25) { + + auto input = NDArrayFactory::create('c', {1}); + auto paddings = NDArrayFactory::create('c', {1,2}, {1,1}); + auto expected = NDArrayFactory::create('c', {3}, {1.,1.,1}); + + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests26) { + + auto input = NDArrayFactory::create('c', {1}); + auto paddings = NDArrayFactory::create('c', {1,2}, {3,2}); + auto expected = NDArrayFactory::create('c', {6}, {0., 0., 0., 1., 0., 0.}); + + input.linspace(1.f); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests27) { + + NDArray input('c', {2,3}, sd::DataType::FLOAT32); + NDArray paddings('c', {2,2}, {0,0,0,1}, sd::DataType::INT32); + NDArray exp('c', {2,4}, {1,1,1,0,1,1,1,0}, sd::DataType::FLOAT32); + NDArray z('c', {2,4}, sd::DataType::FLOAT32); + input = 1.; + + sd::ops::pad op; + Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant + // z.printIndexedBuffer(); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(exp.isSameShapeStrict(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, 
pad_tests28) { + + NDArray input('c', {1,111,111,32}, sd::DataType::FLOAT32); + NDArray paddings('c', {4,2}, {0,0,0,1,0,1,0,0}, sd::DataType::INT32); + NDArray z('c', {1,112,112,32}, sd::DataType::FLOAT32); + input = 1.; + + sd::ops::pad op; + Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {0}, {0}, {}); // constant + // z.printIndexedBuffer(); + + NDArray sum = z.reduceNumber(sd::reduce::Sum); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_EQ(sum.e(0), 111*111*32); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests29) { + + auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); +// auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); +// auto value(10.0); + + auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); + + sd::ops::pad op; + + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests30) { + + auto in = NDArrayFactory::create({1., 11., 111., 11., 1.}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); + + auto exp = NDArrayFactory::create({1., 1., 11., 111., 11., 1., 1.}); + + sd::ops::pad op; + + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests31) { + + auto in = NDArrayFactory::create({1., 11., 111., 1111., 11111.}); +// auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); +// auto value(10.0); + + auto 
exp = NDArrayFactory::create({11., 1., 11., 111., 1111., 11111., 1111.}); + + sd::ops::pad op; + + auto res = op.evaluate({&in, &pad}, {10.0}, {1}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + +/////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests32) { + + auto in = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 4., 5.,6,7,8,9}); + auto pad = NDArrayFactory::create('c', {2,2}, {1, 2, 2, 3}); + + auto exp = NDArrayFactory::create('c', {6,8}, {2, 1, 1, 2, 3, 3, 2, 1, 2, 1, 1, 2, 3, 3, 2, 1, 5, 4, 4, 5, 6, 6, 5, 4, 8, 7, 7, 8, 9, 9, 8, 7, 8, 7, 7, 8, 9, 9, 8, 7, 5, 4, 4, 5, 6, 6, 5, 4}); + + sd::ops::pad op; + + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} +/////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests33) { + + auto in = NDArrayFactory::create('c', {2,3,4}, {1, 2, 3, 4,5, 6, 7, 8,9,10,11,12,13, 14, 15, 16,17, 18, 19, 20,21, 22, 23, 24}); + + auto pad = NDArrayFactory::create('c', {3,2}, {1, 2, 2, 3, 3,3}); + + auto exp = NDArrayFactory::create('c', {5,8,10}, { 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 11,10,9,9,10,11,12,12,11,10., + 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., + 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10.,7,6,5,5,6,7,8,8,7,6., + 3,2,1,1,2,3,4,4,3,2., 19,18,17,17,18,19,20,20,19,18., 15,14,13,13,14,15,16,16,15,14., 15,14,13,13,14,15,16,16,15,14., + 19,18,17,17,18,19,20,20,19,18., 23,22,21,21,22,23,24,24,23,22., 23,22,21,21,22,23,24,24,23,22., 19,18,17,17,18,19,20,20,19,18., + 15,14,13,13,14,15,16,16,15,14., 19,18,17,17,18,19,20,20,19,18., 15,14,13,13,14,15,16,16,15,14., 15,14,13,13,14,15,16,16,15,14., + 
19,18,17,17,18,19,20,20,19,18., 23,22,21,21,22,23,24,24,23,22., 23,22,21,21,22,23,24,24,23,22., 19,18,17,17,18,19,20,20,19,18., + 15,14,13,13,14,15,16,16,15,14., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2., 3,2,1,1,2,3,4,4,3,2., 7,6,5,5,6,7,8,8,7,6., + 11,10,9,9,10,11,12,12,11,10., 11,10,9,9,10,11,12,12,11,10., 7,6,5,5,6,7,8,8,7,6., 3,2,1,1,2,3,4,4,3,2.}); + sd::ops::pad op; + + auto res = op.evaluate({&in, &pad}, {10.0}, {2}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, pad_tests34) { + + NDArray input('c', {5}, {0.778786, 0.801198, 0.724375, 0.230894, 0.727141}, sd::DataType::FLOAT32); + NDArray paddings('c', {1,2}, {1,1}, sd::DataType::INT32); + NDArray expected('c', {7}, {10., 0.778786, 0.801198, 0.724375, 0.230894, 0.727141, 10.}, sd::DataType::FLOAT32); + NDArray z('c', {7}, sd::DataType::FLOAT32); + + sd::ops::pad op; + Nd4jStatus status = op.execute({&input, &paddings}, {&z}, {10}, {0}, {}); // constant + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.isSameShapeStrict(z)); + ASSERT_TRUE(expected.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////// +// CONSTANT mode 2D +TEST_F(DeclarableOpsTests12, Pad_1) { + + double inBuff[] = {1,2,3,4,5,6}; + int padBuff[] = {1,1,2,2}; + double expBuff[] = {0,0,0,0,0,0,0, 0,0,1,2,3,0,0, 0,0,4,5,6,0,0, 0,0,0,0,0,0,0}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + 
+//////////////////////////////////////////////////////////////////// +// REFLECT mode 2D +TEST_F(DeclarableOpsTests12, Pad_2) { + + double inBuff[] = {1,2,3,4,5,6}; + int padBuff[] = {1,1,2,2}; + double expBuff[] = {6,5,4,5,6,5,4, 3,2,1,2,3,2,1, 6,5,4,5,6,5,4, 3,2,1,2,3,2,1}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + +//////////////////////////////////////////////////////////////////// +// SYMMETRIC mode 2D +TEST_F(DeclarableOpsTests12, Pad_3) { + + double inBuff[] = {1,2,3,4,5,6}; + int padBuff[] = {1,1,2,2}; + double expBuff[] = {2,1,1,2,3,3,2, 2,1,1,2,3,3,2, 5,4,4,5,6,6,5, 5,4,4,5,6,6,5}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {2,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + +//////////////////////////////////////////////////////////////////// +// CONSTANT mode 3D +TEST_F(DeclarableOpsTests12, Pad_4) { + + double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + int padBuff[] = {1,1,2,2,2,2}; + double expBuff[] = {0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 1, 2, 3,0,0,0,0, 4, 5, 6,0,0,0,0, 7, 8, 
9,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0,10,11,12,0,0,0,0,13,14,15,0,0,0,0,16,17,18,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0,0,0, 0, 0, 0,0,0}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + + +//////////////////////////////////////////////////////////////////// +// REFLECT mode 3D +TEST_F(DeclarableOpsTests12, Pad_5) { + + double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + int padBuff[] = {1,1,2,2,2,2}; + double expBuff[] = {18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 15,14,13,14,15,14,13, 18,17,16,17,18,17,16, 15,14,13,14,15,14,13, 12,11,10,11,12,11,10, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 9, 8, 7, 8, 9, 8, 7, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // 
result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + +//////////////////////////////////////////////////////////////////// +// SYMMETRIC mode 3D +TEST_F(DeclarableOpsTests12, Pad_6) { + + double inBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18}; + int padBuff[] = {1,1,2,2,2,2}; + double expBuff[] = {5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5, 2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 8, 7, 7, 8, 9, 9, 8, 8, 7, 7, 8, 9, 9, 8, 5, 4, 4, 5, 6, 6, 5, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14, 14,13,13,14,15,15,14, 11,10,10,11,12,12,11, 11,10,10,11,12,12,11, 14,13,13,14,15,15,14, 17,16,16,17,18,18,17, 17,16,16,17,18,18,17, 14,13,13,14,15,15,14}; + + auto input = NDArrayFactory::create(inBuff, 'c', {2,3,3}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {3,2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4,7,7}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +// CONSTANT mode 4D +TEST_F(DeclarableOpsTests12, Pad_7) +{ + + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 0, 7, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 10, 0, 0, 11, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 13, 14, 0, 0, 15, 16, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +// REFLECT mode 4D +TEST_F(DeclarableOpsTests12, Pad_8) +{ + + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = {16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 9, 10, 9, 16, 15, 16, 15, 14, 13, 14, 13, 16, 15, 16, 15, 14, 13, 14, 13, 12, 11, 12, 11, 10, 9, 10, 9, 12, 11, 12, 11, 10, 
9, 10, 9, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1, 8, 7, 8, 7, 6, 5, 6, 5, 8, 7, 8, 7, 6, 5, 6, 5, 4, 3, 4, 3, 2, 1, 2, 1, 4, 3, 4, 3, 2, 1, 2, 1}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +////////////////////////////////////////////////////////////////// +// SYMMETRIC mode 4D +TEST_F(DeclarableOpsTests12, Pad_9) +{ + + double inBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + int padBuff[] = {1, 1, 1, 1, 1, 1, 1, 1}; + double expBuff[] = {1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 1, 1, 2, 2, 1, 1, 2, 2, 3, 3, 4, 4, 3, 3, 4, 4, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 5, 5, 6, 6, 5, 5, 6, 6, 7, 7, 8, 8, 7, 7, 8, 8, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 9, 9, 10, 10, 9, 9, 10, 10, 11, 11, 12, 12, 11, 11, 12, 12, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16, 13, 13, 14, 14, 13, 13, 14, 14, 15, 15, 16, 16, 15, 15, 16, 16}; + auto input = NDArrayFactory::create(inBuff, 'c', {2, 2, 2, 2}); + auto paddings = NDArrayFactory::create(padBuff, 'c', {4, 2}); + auto 
expected = NDArrayFactory::create(expBuff, 'c', {4, 4, 4, 4}); + + sd::ops::pad op; + auto results = op.evaluate({&input, &paddings}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + // result->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +TEST_F(DeclarableOpsTests12, Test_Expose_1) { + auto input0 = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 6, 5, 4}); + auto input1 = NDArrayFactory::create('c', {2, 3}, {3, 2, 1, 4, 5, 6}); + + sd::ops::expose op; + + auto result = op.evaluate({&input0, &input1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z0 = result.at(0); + auto z1 = result.at(1); + + ASSERT_TRUE(input0.equalsTo(z0)); + ASSERT_TRUE(input1.equalsTo(z1)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, Pad_SGO_Test_1) { + + auto in = NDArrayFactory::create({1., 1., 1., 1., 1.}); +// auto pad = NDArrayFactory::create('c', {1, 2}, {1., 1.});// = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}); + auto pad = NDArrayFactory::create('c', {1, 2}, {1, 1}); +// auto value(10.0); + + auto exp = NDArrayFactory::create({10., 1., 1., 1., 1., 1., 10.}); + + sd::ops::pad op; + + auto res = op.evaluate({&in, &pad}, {10.0}, {0}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + // res.at(0)->printIndexedBuffer("PAD_SGO"); + // exp.printIndexedBuffer("PAD_EXP"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_1) { + + auto in = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.}); + auto exp = NDArrayFactory::create('c', {3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7}); + auto pExp = NDArrayFactory::create('c', {3}, {0, 1, 2}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = 
res.at(0); + auto p = res.at(1); +// z->printIndexedBuffer("Triangulars"); +// p->printIndexedBuffer("Permutaions"); + + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(pExp.equalsTo(p)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_2) { + auto in = NDArrayFactory::create('c', {3,3}, {1, 0, 0, 2, 3, 0, 4, 5, 6}); + + auto expLU = NDArrayFactory::create('c', {3,3}, {4., 5., 6., 0.25, -1.25, -1.5, 0.5, -0.4, -3.6}); + auto expP = NDArrayFactory::create({2, 0, 1}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); +// z->printIndexedBuffer("Triangulars2"); +// p->printIndexedBuffer("Permutaions2"); + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_3) { + auto in = NDArrayFactory::create('c', {3,3}, {1,2,3,4,7,9, 11, 12, 13}); + + auto expLU = NDArrayFactory::create('c', {3,3}, { + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753}); + + auto expP = NDArrayFactory::create({2, 1, 0}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); +// z->printIndexedBuffer("Triangulars3"); +// p->printIndexedBuffer("Permutaions3"); + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_4) { + + auto in = NDArrayFactory::create('c', {10,10}, { + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., + 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., + 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., + 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., + 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 
1., 5., 1., 2, 31., 35., + 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., + 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}); + + auto expLU = NDArrayFactory::create('c', {10,10}, { + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, + 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, + 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, + 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, + 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, + 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, + 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695 + }); + + auto expP = NDArrayFactory::create({1, 2, 7, 3, 6, 8, 5, 4, 0, 9}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); +// z->printBuffer("Triangulars4"); +// expLU.printBuffer("TriangulExp4"); +// p->printBuffer("Permutaions4"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + +} + +TEST_F(DeclarableOpsTests12, LU_Test_5) { + + auto in = NDArrayFactory::create('c', {2, 10,10}, { + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., + 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., + 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., + 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., + 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., + 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., + 
1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., + + 1., 2., 3., 4., 5., 6., 7., 8., 1., 15., + 5., 1., 13., 4., 15., 1., 17., 9., 11., 25., + 1., 9., 1., 4., 5., 2., 13., 10, 21., 15., + 3., 9., 4., 1., 5., 3., 7., 1, 1., 5., + 2., 3., 2., 5., 4., 4., 7., 3, 3., 4., + 0., 1., 3., 3., 5., 1., 3., 1, 31., 15., + 2., 1., 4., 3., 1., 5., 1., 2, 31., 35., + 3., 4., 3., 3., 4., 4., 4., 1., 3., 1., + 1., 1., 1., 1., 5., 6., 5., 4., 3., 2., + 1., 1., 1., 1., 1., 1., 1., 1., 1., 1. + }); + + auto expLU = NDArrayFactory::create('c', {2, 10,10}, { + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, + 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, + 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, + 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, + 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, + 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, + 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695, + + 5.0, 1.0, 13.0, 4.0, 15.0, 1.0, 17.0, 9.0, 11.0, 25.0, + 0.2, 8.8, -1.6, 3.2, 2.0, 1.8, 9.6, 8.2, 18.8, 10.0, + 0.6, 0.386364, -4.181818, -0.636364, -5.772727, 2.704545, -9.909091, -7.568182, -10.863636, -17.863636, + 0.6, 0.954545, 0.543478, -4.108696, -2.771739, -0.788043, -6.978261, -8.114130, -17.641304, -9.836957, + 0.4, 0.068182, 0.260870, -0.328042, -4.539683, 3.513228, -6.158730, -2.846561, 22.365079, 25.751323, + 0.2, 0.090909, 0.347826, -0.031746, -0.823427, 
7.563520, -1.118881, 1.485431, 20.725524, 23.196387, + 0.0, 0.113636, -0.760870, -0.523810, 0.236014, 0.213036, -7.593805, -9.585099, 1.663379, -15.900300, + 0.4, 0.295455, 0.652174, -0.698413, 0.167832, 0.021727, -0.001360, -3.321530, -16.392106, - 9.022119, + 0.2, 0.204545, -0.173913, -0.592593, 0.232517, 0.610602, 0.277466, -0.244631, -39.715757, -18.928178, + 0.2, 0.090909, 0.347826, -0.031746, 0.057692, -0.070344, -0.030154, -0.243578, 0.087256, 0.112695 + + }); + + auto expP = NDArrayFactory::create('c', {2, 10}, { + 1, 2, 7, 3, 6, 8, 5, 4, 0, 9, + 1, 2, 7, 3, 6, 8, 5, 4, 0, 9 + }); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); +// z->printBuffer("Triangulars5"); +// expLU.printBuffer("TriangulExp5"); +// p->printBuffer("Permutaions5"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_1_2) { + + auto in = NDArrayFactory::create('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7.,1., 2., 3., 0., 2., 3., 0., 0., 7.}); + auto exp = NDArrayFactory::create('c', {2, 3,3}, {1., 2., 3., 0., 2., 3., 0., 0., 7, 1., 2., 3., 0., 2., 3., 0., 0., 7.}); + + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); +// z->printIndexedBuffer("Triangulars (2,3,3)"); +// p->printIndexedBuffer("Permutaions (2,3,3)"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_3_2) { + + auto in = NDArrayFactory::create('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,1,2,3,4,7,9, 11, 12, 13}); + + auto expLU = NDArrayFactory::create('c', {2, 3,3}, { + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753, + + 11., 12., 13., + 0.36363637, 
2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753 + }); + + auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 2, 1, 0}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); +// z->printIndexedBuffer("Triangulars3_2"); +// p->printIndexedBuffer("Permutaions3_2"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_3_3) { + + auto in = NDArrayFactory::create('c', {2, 3,3}, {1,2,3,4,7,9, 11, 12, 13,13,2,3,4,7,9, 11, 12, 1}); + auto expLU = NDArrayFactory::create('c', {2, 3,3}, { + 11., 12., 13., + 0.36363637, 2.6363635, 4.272727, + 0.09090909, 0.3448276, 0.34482753, + + 13., 2., 3., + 0.84615386, 10.307693, -1.5384617, + 0.30769232, 0.619403, 9.029851}); + + auto expP = NDArrayFactory::create('c', {2,3}, {2, 1, 0, 0, 2, 1}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); +// z->printIndexedBuffer("Triangulars3_3"); +// p->printIndexedBuffer("Permutaions3_3"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_4_1) { + + auto in = NDArrayFactory::create('c', {2, 2,2}, { + 0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f + }); + + auto expLU = NDArrayFactory::create('c', {2, 2,2}, { + 0.7788f, 0.8012f, 0.930149f, -0.514335f, + 0.7271f, 0.1804f, 0.695365f, 0.767056f + }); + + auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); + sd::ops::lu op; + + auto res = op.evaluate({&in}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); +// z->printIndexedBuffer("Triangulars4_1"); +// p->printIndexedBuffer("Permutaions4_1"); + + 
ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, LU_Test_4_2) { + + auto in = NDArrayFactory::create('c', {2, 2,2}, { + 0.7788f, 0.8012f, 0.7244f, 0.2309f, + 0.7271f, 0.1804f, 0.5056f, 0.8925f + }); + + auto expLU = NDArrayFactory::create('c', {2, 2,2}, { + 0.7788f, 0.8012f, 0.930149f, -0.514335f, + 0.7271f, 0.1804f, 0.695365f, 0.767056f + }); + + auto expP = NDArrayFactory::create('c', {2,2}, {0, 1, 0, 1}); + sd::ops::lu op; + + auto res = op.evaluate({&in}, {}, {sd::DataType::INT64}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + auto p = res.at(1); + +// z->printIndexedBuffer("Triangulars4_2"); +// p->printIndexedBuffer("Permutaions4_2"); + + ASSERT_TRUE(expLU.equalsTo(z)); + ASSERT_TRUE(expP.equalsTo(p)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, QR_Test_1) { + + auto in = NDArrayFactory::create('c', {5,3}, { + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. + }); + auto expQ = NDArrayFactory::create('c', {5, 5}, { + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485 + }); + + auto expR = NDArrayFactory::create('c', {5,3}, { + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. 
}); + sd::ops::qr op; + auto res = op.evaluate({&in}, {}, {}, {true}); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto q = res.at(0); + auto r = res.at(1); +// q->printIndexedBuffer("Orthogonal 5x5"); +// expQ.printBuffer("Orthogonal Exp"); +// r->printIndexedBuffer("Upper triangular 5x3"); +// expR.printBuffer("Upper triangular Exp"); +// q->printShapeInfo("Q shape"); +// r->printShapeInfo("R shape"); + sd::ops::matmul opMul; + auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp->isSameShape(in)); +// ASSERT_TRUE(q->isSameShape(expQ)); + + //ASSERT_TRUE(expQ.equalsTo(q)); + ASSERT_TRUE(exp->equalsTo(in)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, QR_Test_1_1) { + + auto in = NDArrayFactory::create('c', {4, 5, 3}, { + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3., + 12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3. 
+ }); + auto expQ = NDArrayFactory::create('c', {4, 5, 5}, { + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485, + 0.8464148, 0.3912908, -0.3431241, 0.06613743, -0.09146205, -0.42320737, -0.9040873, 0.02927014, 0.01737854, -0.04861044, 0.28213826, -0.17042054, -0.93285596, -0.02194202, 0.14371186, 0.07053456, -0.01404065, 0.00109937, 0.99740064, 0.00429488, -0.14106913, 0.0166551, 0.10577161, 0.00585613, 0.98417485 + }); + + auto expR = NDArrayFactory::create('c', {4, 5,3}, { + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0., + -14.177447, -20.666622, 13.401566, 0., -175.04254, 70.080315, 0., 0., 35.201546, 0., 0., 0., 0., 0., 0. 
+ }); + sd::ops::qr op; + auto res = op.evaluate({&in}, {}, {}, {true}); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto q = res.at(0); + auto r = res.at(1); +// q->printIndexedBuffer("Orthogonal 5x5"); +// expQ.printBuffer("Orthogonal Exp"); +// r->printIndexedBuffer("Upper triangular 5x3"); +// expR.printBuffer("Upper triangular Exp"); +// q->printShapeInfo("Q shape"); +// r->printShapeInfo("R shape"); + sd::ops::matmul opMul; + auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp->isSameShape(in)); +// ASSERT_TRUE(q->isSameShape(expQ)); + + //ASSERT_TRUE(expQ.equalsTo(q)); + ASSERT_TRUE(exp->equalsTo(in)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, QR_Test_2) { + + auto in = NDArrayFactory::create('c', {5,3}, {12., -51., 4., 6., 167., -68., -4., 24., -41., -1., 1., 0., 2., 0., 3.}); + auto expQ = NDArrayFactory::create('c', {5, 3}, {0.8464148,0.3912908,-0.3431241,-0.42320737, -0.9040873,0.02927014,0.28213826, -0.17042054, -0.93285596,0.07053456, -0.01404065,0.00109937,-0.14106913,0.0166551,0.10577161}); + auto expR = NDArrayFactory::create('c', {3,3}, {-14.177447,-20.666622,13.401566,0.,-175.04254,70.080315,0.,0.,35.201546}); + + sd::ops::qr op; + auto res = op.evaluate({&in}, {}, {}, {false}); + + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto q = res.at(0); + auto r = res.at(1); + ASSERT_TRUE(q->isSameShape(expQ)); + ASSERT_TRUE(r->isSameShape(expR)); + + sd::ops::matmul opMul; + auto res2 = opMul.evaluate({q, r}); //MmulHelper::matmul(q, r, &in, false, false); + auto exp = res2.at(0);//->printIndexedBuffer("Result as result"); + ASSERT_TRUE(exp->isSameShape(in)); + ASSERT_TRUE(exp->equalsTo(in)); + +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test1) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 
15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 0.628328f, 0.97913796f, 1.8058043f, 2.563919f, 2.844548f, + 3.6026628f, 4.4293294f, 4.7801394f, 2.9474494f, 3.2982588f, + 4.1249247f, 4.8830395f, 5.1636696f, 5.9217834f, 6.7484493f, + 7.09926f, 8.165832f, 8.516642f, 9.3433075f, 10.101422f, + 10.382052f, 11.140167f, 11.966835f, 12.317646f, 10.924093f, + 11.274903f, 12.10157f, 12.859686f, 13.140315f, 13.898429f, + 14.725095f, 15.075906f, 13.682358f, 14.033167f, 14.859833f, + 15.617949f, 15.898578f, 16.656693f, 17.48336f, 17.834171f, + 18.900742f, 19.251549f, 20.078213f, 20.83633f, 21.11696f, + 21.875074f, 22.701742f, 23.052553f, 21.219858f, 21.57067f, + 22.397337f, 23.155449f, 23.436079f, 24.194195f, 25.020863f, + 25.371672f + }); + + sd::ops::image_resize op; + // resize with lancos5 without antialising and aspect ratio preserving + auto results = op.evaluate({&input, &size}, {}, {ops::helpers::kResizeLanczos5}, {false, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result->printBuffer("Lancos5 Resized to 7x8"); +// expected.printBuffer("Lancos5 Expect for 7x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test2) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 0.628328f, 0.97913796f, 1.8058043f, 2.563919f, 2.844548f, + 3.6026628f, 4.4293294f, 4.7801394f, 2.9474494f, 3.2982588f, + 4.1249247f, 4.8830395f, 5.1636696f, 5.9217834f, 6.7484493f, + 7.09926f, 8.165832f, 8.516642f, 9.3433075f, 10.101422f, + 10.382052f, 11.140167f, 11.966835f, 12.317646f, 10.924093f, + 11.274903f, 12.10157f, 
12.859686f, 13.140315f, 13.898429f, + 14.725095f, 15.075906f, 13.682358f, 14.033167f, 14.859833f, + 15.617949f, 15.898578f, 16.656693f, 17.48336f, 17.834171f, + 18.900742f, 19.251549f, 20.078213f, 20.83633f, 21.11696f, + 21.875074f, 22.701742f, 23.052553f, 21.219858f, 21.57067f, + 22.397337f, 23.155449f, 23.436079f, 24.194195f, 25.020863f, + 25.371672f + }); + + sd::ops::image_resize op; + // resize with lanczos5 without antialising and aspect ratio preserving + auto results = op.evaluate({&input, &size}, {}, {ops::helpers::kResizeLanczos5}, {false, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result.printBuffer("Lanczos5 Resized to 8x7"); +// expected.printBuffer("Lanczos5 Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test3) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 0.6537938f, 1.0309073f, 1.8018917f, 2.4606667f, 2.9888396f, 3.6476145f, 4.418599f, + 4.7957115f, 3.1913466f, 3.5684595f, 4.3394437f, 4.998219f, 5.526393f, 6.185168f, + 6.956152f, 7.3332644f, 7.626866f, 8.00398f, 8.774965f, 9.433739f, 9.961912f, + 10.620688f, 11.391673f, 11.7687845f, 10.929041f, 11.306154f, 12.077138f, 12.735914f, + 13.264087f, 13.922862f, 14.693848f, 15.07096f, 14.231217f, 14.60833f, 15.379314f, + 16.038086f, 16.56626f, 17.225037f, 17.996023f, 18.373135f, 18.666735f, 19.043848f, + 19.814833f, 20.473606f, 21.00178f, 21.660557f, 22.431541f, 22.808653f, 21.204287f, + 21.581398f, 22.352386f, 23.01116f, 23.539333f, 24.19811f, 24.969095f, 25.346205f + }); + + sd::ops::image_resize op; + // resize with lanczos3 without antialising and aspect ratio preserving + auto results = op.evaluate({&input, 
&size}, {}, {ops::helpers::kResizeLanczos3}, {false, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result.printBuffer("Lanczos3 Resized to 8x7"); +// expected.printBuffer("Lanczos3 Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test4) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 1.4150869f, 1.7928237f, 2.4084527f, 3.0680697f, 3.6419308f, 4.301548f, 4.9171767f, + 5.294914f, 4.012885f, 4.390622f, 5.0062513f, 5.6658688f, 6.23973f, 6.899347f, + 7.514975f, 7.8927126f, 7.358912f, 7.736648f, 8.352278f, 9.011895f, 9.585756f, + 10.245375f, 10.861001f, 11.238739f, 11.060086f, 11.437822f, 12.0534525f, 12.713069f, + 13.28693f, 13.946548f, 14.562176f, 14.939912f, 14.761261f, 15.138998f, 15.754629f, + 16.414246f, 16.988108f, 17.647724f, 18.263351f, 18.641088f, 18.107288f, 18.485023f, + 19.100655f, 19.760273f, 20.334133f, 20.993752f, 21.609377f, 21.987114f, 20.705086f, + 21.082823f, 21.698452f, 22.35807f, 22.93193f, 23.591549f, 24.207174f, 24.584913f + }); + + sd::ops::image_resize op; + // resize with gaussian without antialising and aspect ratio preserving + auto results = op.evaluate({&input, &size}, {}, {ops::helpers::kResizeGaussian}, {false, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result.printBuffer("Lanczos3 Resized to 8x7"); +// expected.printBuffer("Lanczos3 Expect for 8x7"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test5) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 0.6372399f, 1.0536414f, 1.7716959f, 2.3966959f, 3.0216959f, 3.6466963f, 4.3647504f, 4.781152f, + 3.3926036f, 3.8090053f, 4.5270596f, 5.1520596f, 5.7770596f, 6.4020596f, 7.1201134f, 7.5365143f, + 7.358708f, 7.7751093f, 8.493164f, 9.118163f, 9.743165f, 10.368165f, 11.086218f, 11.502619f, + 10.928043f, 11.344445f, 12.0625f, 12.6875f, 13.3125f, 13.9375f, 14.655554f, 15.071955f, + 14.49738f, 14.913782f, 15.631836f, 16.256836f, 16.881836f, 17.506836f, 18.22489f, 18.64129f, + 18.463486f, 18.879889f, 19.597942f, 20.222942f, 20.847942f, 21.472942f, 22.190996f, 22.607397f, + 21.218851f, 21.635252f, 22.353308f, 22.978308f, 23.603308f, 24.228308f, 24.946362f, 25.362762f + }); + + sd::ops::image_resize op; + // resize with bicubic without antialising and aspect ratio preserving + auto results = op.evaluate({&input, &size}, {}, {ops::helpers::kResizeBicubic}, {false, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result->printBuffer("Bicubic Resized to 7x8"); +// expected.printBuffer("Bicubic Expect for 7x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test6) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 0.63678247f, 1.0531839f, 1.7712381f, 2.396238f, 3.021238f , 3.646238f, 4.364292f, 4.780694f, + 3.3934183f, 3.8098197f, 4.5278745f, 5.1528745f, 5.7778745f, 6.402874f, 7.1209283f, 7.5373297f, + 7.3566165f, 7.7730184f, 8.491073f, 9.116073f, 9.741073f, 10.366074f , 11.084127f , 11.500528f, + 10.928043f, 11.344445f, 12.0625f , 12.6875f , 13.3125f , 
13.9375f , 14.655554f, 15.071955f , 14.499474f , 14.915876f , 15.633932f, 16.25893f, 16.883932f, 17.508932f, 18.226984f , 18.643385f, + 18.46267f, 18.87907f, 19.597128f, 20.222126f , 20.847128f, 21.472126f, 22.190182f , 22.606583f , 21.219305f, 21.635706f , + 22.353762f, 22.978762f , 23.603762f , 24.228764f, 24.946815f , 25.363216f + }); + + sd::ops::image_resize op; + // resize with bicubic with antialising and without aspect ratio preserving + auto results = op.evaluate({&input, &size}, {}, {ops::helpers::kResizeBicubic}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result->printBuffer("Bicubic Resized to 7x8"); +// expected.printBuffer("Bicubic Expect for 7x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test7) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 0.98593485f, 1.3872082f, 2.0625007f, 2.6875007f, 3.3125012f, 3.937501f, 4.612794f, 5.014066f, + 3.6096964f, 4.01097f, 4.6862626f, 5.311262f, 5.936263f, 6.561262f, 7.2365556f, 7.637828f, + 7.4145045f, 7.8157787f, 8.491071f, 9.116072f, 9.741073f, 10.366072f, 11.041365f, 11.4426365f, + 10.985933f, 11.387209f, 12.062499f, 12.687501f, 13.312502f, 13.9375f, 14.612794f, 15.014066f, + 14.557361f, 14.958637f, 15.633926f, 16.25893f, 16.88393f, 17.508926f, 18.18422f, 18.585491f, + 18.36217f, 18.763443f, 19.438736f, 20.063736f, 20.688738f, 21.313736f, 21.98903f, 22.3903f, + 20.985931f, 21.387209f, 22.0625f, 22.6875f, 23.3125f, 23.937498f, 24.612793f, 25.014061f + }); + + sd::ops::image_resize op; + // resize with Mitchell cubic with antialising and without aspect ratio preserving + auto results = op.evaluate({&input, &size}, {}, 
{ops::helpers::kResizeMitchellcubic}, {false, true}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result->printBuffer("Mitchell cubic Resized to 7x8"); +// expected.printBuffer("Mitchell cubic Expect for 7x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test8) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 1.f , 1.4375f , 2.0625f , 2.6875f , 3.3125f , 3.9375f , 4.5625f , 5.f , + 3.8571427f, 4.2946424f, 4.9196424f, 5.5446424f, 6.1696424f, 6.7946424f, 7.4196424f, 7.8571424f, + 7.4285717f, 7.8660717f, 8.491072f , 9.116072f , 9.741072f , 10.366072f , 10.991072f , 11.428572f , + 11.f , 11.4375f , 12.0625f , 12.6875f , 13.3125f , 13.9375f , 14.5625f , 15.f , + 14.571429f , 15.008929f, 15.633929f, 16.25893f , 16.88393f , 17.50893f , 18.13393f , 18.57143f , + 18.142857f , 18.580357f, 19.205357f, 19.830357f , 20.455357f , 21.080357f , 21.705357f , 22.142857f , + 21.f , 21.4375f , 22.0625f , 22.6875f , 23.3125f , 23.9375f , 24.5625f , 25.f + }); + + sd::ops::image_resize op; + // resize with bilinear without antialising and aspect ratio preserving + auto results = op.evaluate({&input, &size}, {}, {ops::helpers::kResizeBilinear}, {false, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result->printBuffer("Bilinear Resized to 7x8"); +// expected.printBuffer("Bilinear Expect for 7x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test9) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 
21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 1.f , 1.4f , 2.f , 2.8f , 3.2f , 4.f , 4.6f , 5.f , + 4.f , 4.4f , 5.f , 5.8f , 6.2f , 7.f , 7.6f , 8.f , + 6.999998f, 7.399998f, 7.999998f, 8.799997f, 9.199997f, 9.999997f, 10.599997f, 10.999996f, + 11.f, 11.399999f, 12.f, 12.799999f, 13.199999f, 13.999998f, 14.599998f, 14.999999f, + 15.f, 15.4f, 16.f, 16.8f, 17.2f, 18.f, 18.6f, 19.f, 17.999989f, + 18.399990f, 18.999989f, 19.799988f, 20.199987f, 20.999989f, 21.599989f, 21.999989f, 21.f, + 21.4f, 22.f, 22.8f, 23.2f, 24.f, 24.6f, 25.f + }); + + sd::ops::image_resize op; + // resize with area without antialising and aspect ratio preserving + auto results = op.evaluate({&input, &size}, {}, {ops::helpers::kResizeArea}, {false, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result->printBuffer("Area Resized to 7x8"); +// expected.printBuffer("Area Expect for 7x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test10) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 10, 6, + 6, 7, 8, 8, 9, 10, 10, 11, 11, 12, 13, 13, 14, 15, 15, 16, 16, + 17, 18, 18, 19, 20, 20, 16, 16, 17, 18, 18, 19, 20, 20, 21, 21, 22, + 23, 23, 24, 25, 25 + }); + + sd::ops::image_resize op; + // resize with nearest neigbors without antialising and aspect ratio preserving + auto results = op.evaluate({&input, &size}, {}, {ops::helpers::kResizeNearest}, {false, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result->printBuffer("Nearest neighbor Resized 
to 7x8"); +// expected.printBuffer("Nearest neighbor Expect for 7x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +TEST_F(DeclarableOpsTests12, ImageResize_Test11) { + + NDArray input = NDArrayFactory::create('c', {1, 5, 5, 1}, { + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25 + }); + auto size = NDArrayFactory::create({7, 8}); + NDArray expected = NDArrayFactory::create('c', {1, 7, 8, 1}, { + 1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 10, 6, + 6, 7, 8, 8, 9, 10, 10, 11, 11, 12, 13, 13, 14, 15, 15, 16, 16, + 17, 18, 18, 19, 20, 20, 16, 16, 17, 18, 18, 19, 20, 20, 21, 21, 22, + 23, 23, 24, 25, 25 + }); + + sd::ops::image_resize op; + // resize with nearest neigbors without antialising and aspect ratio preserving + auto results = op.evaluate({&input, &size}, {}, {ops::helpers::kResizeNearest}, {false, false}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results[0];///.at(0); +// result->printBuffer("Nearest neighbor Resized to 7x8"); +// expected.printBuffer("Nearest neighbor Expect for 7x8"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_1) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 4.f, 2.f, 4.f, 2.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f }); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests12, TriangularSolve_Test_2) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 1.f, 1.f, 1.f, 1.f, + 0.f, 1.f, 1.f, 0.f, + 0.f, 0.f, 2.f, 1.f, + 0.f, 0.f, 0.f, 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 2.f, 4.f, 2.f, 4.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + 2.f, 4.f, 1.f, 1.3333333f }); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_3) { + + auto a = NDArrayFactory::create('c', {2, 4, 4}, { + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f, + + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f + }); + + auto b = NDArrayFactory::create('c', {2, 4, 1}, { + 4.f, 2.f, 4.f, 2.f, + 4.f, 2.f, 4.f, 2.f + }); + + auto exp = NDArrayFactory::create('c', {2, 4, 1}, { + 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f, + 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f + }); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_4) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 1.f, 1.f, 1.f, 1.f, + 0.f, 1.f, 1.f, 0.f, + 0.f, 0.f, 2.f, 1.f, + 0.f, 0.f, 0.f, 3.f, + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 2.f, 4.f, 2.f, 4.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + -3.3333333f, 3.6666666f, 0.333333f, 1.3333333f + }); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, 
{false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("TriangularSolve"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_5) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 5.f, 1., -3.f, 3.f, + 0.f, 1.f, 1.f, -1.f, + 0.f, 0.f, 2.f, -9.f, + 0.f, 0.f, 0.f, 4.f + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 5.f, 2.f, 0.f, -3.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + 1.f, 1.f, 1.f, 1.f + }); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {false, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("TriangularSolve with adjoint"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, SolveLs_Test_1) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 3.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, + 1.f, 0.f, 1.f, 0.f, + 1.f, 1.f, 1.f, 1.f + }); + + auto b = NDArrayFactory::create('c', {4, 1}, { + 4.f, 2.f, 4.f, 2.f + }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { + 1.333333f, -0.6666667f, 2.6666667f, -1.3333333f }); + + sd::ops::lstsq op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("MatrixSolveLS"); + MmulHelper::matmul(&a, z, &exp, false, false); + + ASSERT_TRUE(exp.equalsTo(b)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, SolveLs_Test_2) { + + auto a = NDArrayFactory::create('c', {3, 3}, { + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 11.f, 8.f, 21.f + }); + + auto b = NDArrayFactory::create('c', {3, 1}, { 1.f, 2.f, 3.f }); + + auto exp = NDArrayFactory::create('c', {3, 1}, { -0.24999914f, 0.4999994f, 0.08333314f }); + + sd::ops::lstsq 
op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + + MmulHelper::matmul(&a, z, &exp, false, false); + +// z->printIndexedBuffer("MatrixSolveLS2"); + + ASSERT_TRUE(exp.equalsTo(b)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, SolveLs_Test_3) { + + auto a = NDArrayFactory::create('c', {3, 4}, { + 1.f,1.f,0.f,0.f,-1.f,1.f,0.f,0.f,1.f,1.f,-1.f,-1.f + }); + + auto b = NDArrayFactory::create('c', {3, 1}, { 1.f, 2.f, 3.f }); + + auto exp = NDArrayFactory::create('c', {3, 1}, { -0.5f, 1.5f, -2.f }); + + sd::ops::lstsq op; + + auto res = op.evaluate({&a, &b}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + +// z->printIndexedBuffer("MatrixSolveLS3"); + MmulHelper::matmul(&a, z, &exp, false, false); + ASSERT_TRUE(exp.equalsTo(b)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, SolveLs_Test_4) { + + auto a = NDArrayFactory::create('c', {3, 4}, { + 1.f,1.f,0.f,0.f,-1.f,1.f,0.f,0.f,1.f,1.f,-1.f,-1.f + }); + + auto b = NDArrayFactory::create('c', {3, 1}, { 1.f, 2.f, 3.f }); + + auto exp = NDArrayFactory::create('c', {4, 1}, { -0.5f, 1.5f, -2.f, 0.f}); + + sd::ops::lstsq op; + + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); +// z->printIndexedBuffer("Output_12.4"); +// z->printShapeInfo("Output_12.4 shape"); +// MmulHelper::matmul(&a, z, &exp, false, false); + +// z->printIndexedBuffer("MatrixSolveLS4"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, SolveLs_Test_5) { + + auto a = NDArrayFactory::create('c', {1, 0, 3, 4}); + auto b = NDArrayFactory::create('c', {1, 0, 3, 1}); + + sd::ops::lstsq op; + + auto res = op.evaluate({&a, &b}, {false}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto 
z = res.at(0); + ASSERT_TRUE(z->isEmpty()); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, Solve_Test_6) { + + auto a = NDArrayFactory::create('c', {1, 0, 3, 3}); + auto b = NDArrayFactory::create('c', {1, 0, 3, 1}); + + sd::ops::solve op; + + auto res = op.evaluate({&a, &b}, {true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + ASSERT_TRUE(z->isEmpty()); + + +} + +////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests12, TriangularSolve_Test_6) { + + auto a = NDArrayFactory::create('c', {4, 4}, { + 5.f, 1.f, -3.f, 3.f, + 0.f, 1.f, 1.f, -1.f, + 0.f, 0.f, 2.f, -9.f, + 0.f, 0.f, 0.f, 4.f + }); + + auto b = NDArrayFactory::create('c', {4, 2}, { + 5.f, 1.f, 2.f, 1.f, 0.f, 1.f, -3.f, 1.f + }); + + auto exp = NDArrayFactory::create('c', {4, 2}, { + 1.f,0.2f, 1.f,0.8f, 1.f,0.4f, 1.f,1.2f + }); + + sd::ops::triangular_solve op; + + auto res = op.evaluate({&a, &b}, {}, {}, {false, true}); + ASSERT_EQ(res.status(), ND4J_STATUS_OK); + auto z = res.at(0); + + z->printIndexedBuffer("TriangularSolve with adjoint"); + + ASSERT_TRUE(exp.equalsTo(z)); + +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests13.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests13.cpp new file mode 100644 index 000000000..713548a0e --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests13.cpp @@ -0,0 +1,2862 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. 
+ * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + + +// +// Created by raver on 8/4/2018. +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include + +using namespace sd; + + +class DeclarableOpsTests13 : public testing::Test { +public: + + DeclarableOpsTests13() { + //printf("\n"); + //fflush(stdout); + } +}; + +template +class TypedDeclarableOpsTests13 : public testing::Test { +public: + + TypedDeclarableOpsTests13() { + printf("\n"); + fflush(stdout); + } +}; + +typedef ::testing::Types TestingTypes; +TYPED_TEST_CASE(TypedDeclarableOpsTests13, TestingTypes); + +TEST_F(DeclarableOpsTests13, test_pow_1) { + auto x = NDArrayFactory::create('c', {2, 2}, {2.f, 2.f, 2.f, 2.f}); + auto y = NDArrayFactory::create('c', {2}, {3, 3}); + auto e = NDArrayFactory::create('c', {2, 2}, {8.f, 8.f, 8.f, 8.f}); + + sd::ops::Pow op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); +} + +TEST_F(DeclarableOpsTests13, test_empty_range_1) { + auto start = NDArrayFactory::create(0); + auto limit = NDArrayFactory::create(0); + + sd::ops::range op; + auto result = op.evaluate({&start, &limit}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z->isEmpty()); + + +} + +TEST_F(DeclarableOpsTests13, test_empty_range_2) { + + sd::ops::range op; + auto result = op.evaluate({}, {1.0, 1.0}); + 
ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z->isEmpty()); +} + +TEST_F(DeclarableOpsTests13, test_empty_range_3) { + + sd::ops::range op; + auto result = op.evaluate({}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z->isEmpty()); +} + +TEST_F(DeclarableOpsTests13, test_argmax_edge_1) { + auto ctx = new Context(1); + auto arr = NDArrayFactory::create_('c', {1024,1}); + + ctx->setInputArray(0, arr, true); + ctx->setOutputArray(0, NDArrayFactory::create_('c', {1}), true); + ctx->setInputArray(1, NDArrayFactory::create_(0), true); //Axis 0 + + + sd::ops::argmax op; + auto result = op.execute(ctx); + ASSERT_EQ(Status::OK(), result); + + //nd4j_printf("Done\n",""); + delete ctx; +} + +TEST_F(DeclarableOpsTests13, test_add_1) { + auto x = NDArrayFactory::create('c', {1, 768}); + auto y = NDArrayFactory::create('c', {768}); + auto e = NDArrayFactory::create('c', {1, 768});; + y. assign(1.0f); + e.assign(1.0f); + + x += y; + + ASSERT_EQ(e, x); +} + +TEST_F(DeclarableOpsTests13, test_listdiff_1) { + auto x = NDArrayFactory::create('c', {4}, {0, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {2}, {3, 1}); + + auto od = NDArrayFactory::create('c', {2}); + auto oi = NDArrayFactory::create('c', {2}); + + sd::ops::listdiff op; + auto result = op.execute({&x, &y}, std::vector{&od, &oi}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); +} + +TEST_F(DeclarableOpsTests13, test_greater_1) { + auto x = NDArrayFactory::create('c', {3, 1}); + auto y = NDArrayFactory::create('c', {1, 4}); + + sd::ops::greater op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); +} + +TEST_F(DeclarableOpsTests13, test_eval_reduction_shape_1) { + Nd4jLong axis = 0L; + auto x = NDArrayFactory::create('c', {2}, {4, 2}); + auto y = NDArrayFactory::create('c', {1}, {axis}); + auto exp = NDArrayFactory::create('c', {2}, {1, 2}); + + sd::ops::evaluate_reduction_shape op; + auto result = 
op.evaluate({&x, &y}, {true}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(exp, *z); +} + +TEST_F(DeclarableOpsTests13, test_or_1) { + + NDArray x('c', {4}, {false, true, false, true}, sd::DataType::BOOL); + NDArray y('c', {4}, {false, false, true, true}, sd::DataType::BOOL); + NDArray e('c', {4}, {false, true, true, true}, sd::DataType::BOOL); + + NDArray z('c', {4}, sd::DataType::BOOL); + + x.applyPairwiseTransform(pairwise::Or, y, z); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests13, test_and_1) { + auto x = NDArrayFactory::create('c', {4}, {false, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, true, true}); + auto e = NDArrayFactory::create('c', {4}, {false, false, false, true}); + + auto z = NDArrayFactory::create('c', {4}); + + x.applyPairwiseTransform(pairwise::And, y, z); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests13, test_xor_1) { + auto x = NDArrayFactory::create('c', {4}, {false, true, false, true}); + auto y = NDArrayFactory::create('c', {4}, {false, false, true, true}); + auto e = NDArrayFactory::create('c', {4}, {false, true, true, false}); + + auto z = NDArrayFactory::create('c', {4}); + + x.applyPairwiseTransform(pairwise::Xor, y, z); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_1) { + auto x = NDArrayFactory::create('c', {2,3}, {1,2,3, 4, 5, 6}); + auto y = NDArrayFactory::create('c', {2,3}, {1,-2,3, -4, 5, -6}); + auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); + auto exp = NDArrayFactory::create('c', {2,3}, {1.2,2.2,3.2,4.2,5.2,6.2}); + sd::ops::barnes_gains op; + auto result = op.evaluate({&x, &y, &eps}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printBuffer("Gains out"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); +} + +TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_2) { + auto x = NDArrayFactory::create('c', {2,3}, {1, -2, 3, -4, 5, -6}); + auto y = 
NDArrayFactory::create('c', {2,3}, {1, -2, 3, -4, 5, -6}); + auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); + auto exp = NDArrayFactory::create('c', {2,3}, {1.2, 0.01, 3.2, 0.01, 5.2, 0.01}); + sd::ops::barnes_gains op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printBuffer("Gains out"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + //ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests13, BarnesHutTsne_GainsTest_3) { + auto x = NDArrayFactory::create('c', {2,3}, {-1, 2, -3, 4, -5, 6}); + auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); + auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); + auto exp = NDArrayFactory::create('c', {2,3}, {0.01, 2.2, 0.01, 4.2, 0.01, 6.2}); + sd::ops::barnes_gains op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printBuffer("Gains out"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + +} + +TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_1) { + auto data = NDArrayFactory::create('c', {5,4}); + auto rows = NDArrayFactory::create('c', {2}, {2, 3}); + auto cols = NDArrayFactory::create('c', {5}, {0, 2, 1, 4, 3}); + auto vals = NDArrayFactory::create('c', {5}, {10., 20., 30., 40., 50.}); + //auto buf = NDArrayFactory::create('c', {4}); + auto exp1 = NDArrayFactory::create('c', {5,4}, {-1.846154, -1.846154, -1.846154, -1.846154, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); + //auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. 
+ //std::vector exp({&exp1, &exp2}); + data.linspace(1); + +// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); +// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); +// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_edge_forces op; + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {1}); + + + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printBuffer("Output"); + ASSERT_TRUE(exp1.equalsTo(result.at(0))); + +} + +TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_2) { + auto data = NDArrayFactory::create('c', {5,4}); + auto rows = NDArrayFactory::create('c', {3}, {1,2,3}); + auto cols = NDArrayFactory::create('c', {5}, {1, 2, 0, 4, 3}); + auto vals = NDArrayFactory::create('c', {5}, {10., 20., 30., 40., 50.}); + //auto buf = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create('c', {5,4}, {-0.622568, -0.622568, -0.622568, -0.622568, 1.846154, 1.846154, 1.846154, 1.846154, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); + //auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. 
+ //std::vector exp({&exp1, &exp2}); + data.linspace(1); + +// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); +// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); +// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_edge_forces op; + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {2}); + + + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + +} + +TEST_F(DeclarableOpsTests13, BarnesHutTsne_EdgeForceTest_3) { + auto data = NDArrayFactory::create('c', {11, 5}, {0.3, 0.2625, 0.2674, 0.8604, 0.4803, 0.1096, 0.795, 0.5918, 0.2738, 0.952, 0.969, 0.8586, 0.8088, 0.5338, 0.5961, 0.7187, 0.463, 0.0867, 0.7748, 0.4802, 0.2493, 0.3227, 0.3064, 0.698, 0.7977, 0.7674, 0.168, 0.3107, 0.0217, 0.138, 0.8619, 0.8413, 0.5285, 0.9703, 0.6774, 0.2624, 0.4374, 0.1569, 0.1107, 0.0601, 0.4094, 0.9564, 0.5994, 0.8279, 0.3859, 0.6202, 0.7604, 0.0788, 0.0865, 0.7445, 0.6548, 0.3385, 0.0582, 0.6249, 0.7432}); + auto rows = NDArrayFactory::create({0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); + auto cols = NDArrayFactory::create({4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); + auto vals = NDArrayFactory::create({0.6199614579042966, 0.19644097697184246, 0.13824979367331638, 0.01949900138247239, 0.008923198738222747, 0.008392793826291798, 0.0033348224714784204, 0.0026246189757042166, 0.0025733360563748838, 0.5877136110798608, 0.28250257562439585, 0.08098135424273815, 0.014862718272075049, 0.01219187321450782, 0.01152346362368888, 0.004243137936786281, 0.0034626999030188577, 0.0025185661029283168, 0.6777005651521399, 0.18321248222489303, 
0.04018202465629351, 0.02941935889988646, 0.02164146250842832, 0.019898422145651618, 0.011683461395713935, 0.008439076090480863, 0.007823146926512332, 0.6770900431883232, 0.16617511239723026, 0.06039349887686468, 0.04650913399744179, 0.016886531410284355, 0.014591049666869658, 0.006407638669806174, 0.006074413005122801, 0.0058725787880570205, 0.6278185083409108, 0.235127797795446, 0.07023700015217448, 0.030885483448633774, 0.01229522088606573, 0.009238279699136107, 0.008219511168822047, 0.004303744819835723, 0.0018744536889749907, 0.7122603898978483, 0.07862620103245824, 0.07061257369349086, 0.06721483653169834, 0.028957853952131768, 0.01778978123182596, 0.01481713955181034, 0.005492728917348627, 0.0042284951913875955, 0.5266844101016999, 0.3304104787383107, 0.10930017433210941, 0.018514917515240075, 0.006969360999637938, 0.0063776901975396, 0.0010590388116165708, 6.526830884629785E-4, 3.1246215383067865E-5, 0.7176179284835663, 0.08741734015883978, 0.05927699083866909, 0.04663169573956976, 0.03287576269194147, 0.02993912340339554, 0.013365238657916641, 0.010616858763291145, 0.002259061262810172, 0.6891905160321706, 0.1397658294110526, 0.05438284759722162, 0.05437184733708826, 0.028683289714498808, 0.020986120697576355, 0.007218358114741088, 0.0032834770669826364, 0.002117714028667893, 0.6823873496503976, 0.1345267083671607, 0.08712863515505885, 0.04286621088946242, 0.02544804597749639, 0.01689343932533317, 0.007219134659004873, 0.0019232929717404616, 0.0016071830043453991, 0.6425809622897437, 0.18474464886441516, 0.10897036475298316, 0.03466939253836615, 0.013288054277817787, 0.005149178177380355, 0.0037974063158903518, 0.0037851733015991287, 0.0030148194818042273}); + //auto buf = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create('c', {11, 5}, {-0.080205, -0.085862, 0.024045, 0.133551, -0.199896, -0.170597, 0.187301, 0.205824, -0.165268, 0.131228, 0.155135, 0.021446, 0.217583, -0.262873, -0.021075, 0.114537, 0.088023, -0.039205, 0.087984, 
-0.179565, -0.132683, 0.003677, 0.072081, -0.068737, 0.204481, 0.287223, -0.193989, 0.104569, -0.123401, -0.036368, 0.086745, 0.002961, -0.091327, 0.234853, 0.120270, -0.304006, 0.128305, -0.084867, -0.017550, -0.130837, -0.288569, 0.124679, 0.054078, -0.034187, -0.192599, 0.033196, 0.228182, -0.044972, -0.314217, 0.020287, 0.054427, -0.078887, -0.078246, -0.104543, 0.169803}); + //auto exp2 = NDArrayFactory::create({-4., -4., -4., -4. + //std::vector exp({&exp1, &exp2}); + //data.assign(1.0); //linspace(1); + +// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); +// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); +// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_edge_forces op; + auto result = op.evaluate({&rows, &cols, &vals, &data}, {}, {11}); + + //nd4j_printf("rows %lld, cols %lld, vals %lld, res full %lld\n", rows.lengthOf(), cols.lengthOf(), vals.lengthOf(), exp1.lengthOf()); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printBuffer("Output"); + //exp.printBuffer("Expect"); + //result.at(0)->printShapeInfo("Shape output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + +} + +TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_1) { +// auto data = NDArrayFactory::create('c', {5,4}); + auto rows = NDArrayFactory::create('c', {2}, {0, 1}); + auto cols = NDArrayFactory::create('c', {4}, {0, 1, 1, 0}); + auto vals = NDArrayFactory::create('c', {4}, {20., 30., 40., 50.}); + auto exp = NDArrayFactory::create('c', {1,1}, {20.}); +// data.linspace(1); + +// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); +// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); +// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {1}); + ASSERT_EQ(result.status(), Status::OK()); + 
//result.at(2)->printBuffer("Symmetrized1"); + ASSERT_TRUE(exp.equalsTo(result.at(2))); +} + +TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_2) { + auto rows = NDArrayFactory::create('c', {4}, {0, 2, 2, 3}); + auto cols = NDArrayFactory::create('c', {8}, {0, 1, 1, 0, 0, 1, 1, 1}); + auto vals = NDArrayFactory::create('c', {8}, {20., 30., 40., 50., 120., 130., 140., 150.}); + auto exp = NDArrayFactory::create('c', {1,5}, {20., 15., 15., 20., 20.}); +// data.linspace(1); + +// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); +// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); +// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(2)->printBuffer("Symmetrized2"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + ASSERT_TRUE(exp.equalsTo(result.at(2))); + +} + +TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_3) { + auto rows = NDArrayFactory::create('c', {12}, {0, 2, 3, 5, 7, 8, 9, 11, 12, 14, 18, 21}); + auto cols = NDArrayFactory::create('c', {24}, {0, 1, 2, 3, 4, 5, 4, 3, 2, 1, 0, 1, 0, 2, 4, 3, 2, 1, 0, 1, 2, 3, 4, 5}); + auto vals = NDArrayFactory::create('c', {24}, {20., 30., 40., 50., 120., 130., 140., 150.,220., 230., 240., 250., 2120., 2130., 2140., 2150., 320., 330., 340., 350., 3120., 3130., 3140., 3150.}); + auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); +// data.linspace(1); + +// auto y = NDArrayFactory::create('c', 
{2,3}, {-0.1,-2,3, -4, -0.5, -6}); +// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); +// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(2)->printBuffer("Symmetrized3"); + //exp.printBuffer("EXPect symm3"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + //ASSERT_TRUE(exp.equalsTo(result.at(0))); + +} + +TEST_F(DeclarableOpsTests13, BarnesHutTsne_symmetrized_4) { + auto rows = NDArrayFactory::create({0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); + auto cols = NDArrayFactory::create({4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); + auto vals = NDArrayFactory::create( {0.6200, 0.1964, 0.1382, 0.0195, 0.0089, 0.0084, 0.0033, 0.0026, 0.0026, 0.5877, 0.2825, 0.0810, 0.0149, 0.0122, 0.0115, 0.0042, 0.0035, 0.0025, 0.6777, 0.1832, 0.0402, 0.0294, 0.0216, 0.0199, 0.0117, 0.0084, 0.0078, 0.6771, 0.1662, 0.0604, 0.0465, 0.0169, 0.0146, 0.0064, 0.0061, 0.0059, 0.6278, 0.2351, 0.0702, 0.0309, 0.0123, 0.0092, 0.0082, 0.0043, 0.0019, 0.7123, 0.0786, 0.0706, 0.0672, 0.0290, 0.0178, 0.0148, 0.0055, 0.0042, 0.5267, 0.3304, 0.1093, 0.0185, 0.0070, 0.0064, 0.0011, 0.0007, 3.1246e-5, 0.7176, 0.0874, 0.0593, 0.0466, 0.0329, 0.0299, 0.0134, 0.0106, 0.0023, 0.6892, 0.1398, 0.0544, 0.0544, 0.0287, 0.0210, 0.0072, 0.0033, 0.0021, 0.6824, 0.1345, 0.0871, 0.0429, 0.0254, 0.0169, 0.0072, 0.0019, 0.0016, 0.6426, 0.1847, 0.1090, 0.0347, 0.0133, 0.0051, 0.0038, 0.0038, 0.0030}); + //auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 
145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); +// data.linspace(1); + auto exp4 = NDArrayFactory::create('c', {1, 108}, {0.6239, 0.1813, 0.1236, 0.03695, 0.00795, 0.03385, 0.0074, 0.0158, 0.0013, 0.0042, 0.0074, 0.3093, 0.2085, 0.051, 0.00895, 0.01605, 0.00245, 0.00705, 0.00125, 0.0021, 0.01605, 0.6022, 0.1615, 0.0233, + 0.0183, 0.0108, 0.0068, 0.0042, 0.0113, 0.00115, 0.1813, 0.00125, 0.0233, 0.65985, 0.0653, 0.0779, 0.03565, 0.05085, 0.03835, 0.02625, 0.6239, 0.3093, 0.0068, 0.0653, 0.2099, 0.0205, 0.0173, 0.0073, + 0.0171, 0.0089, 0.0158, 0.0113, 0.03835, 0.71495, 0.04775, 0.03615, 0.0089, 0.00275, 0.0021, 1.5623E-5, 0.00795, 0.00245, 0.6022, 0.0779, 0.0073, 0.5098, 0.0159, 0.00135, 1.5623E-5, 0.03385, 0.00705, + 0.02625, 0.0171, 0.71495, 0.06515, 0.01835, 0.00775, 0.00115, 0.03695, 0.051, 0.1615, 0.03565, 0.0205, 0.00275, 0.5098, 0.00775, 0.0055, 0.0026, 0.0013, 0.2085, 0.0183, 0.05085, 0.0173, 0.04775, + 0.00135, 0.06515, 0.0026, 0.35855, 0.1236, 0.00895, 0.0108, 0.65985, 0.2099, 0.03615, 0.0159, 0.01835, 0.0055, 0.35855}); +// auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); +// auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); +// auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::barnes_symmetrized op; + auto result = op.evaluate({&rows, &cols, &vals}, {}, {11}); + ASSERT_EQ(result.status(), Status::OK()); + auto res = result.at(2); + // res->printBuffer("Symmetrized4"); + // exp4.printBuffer("Expected sym"); + // nd4j_printf("Total res is {1, %lld}\n", res->lengthOf()); + // nd4j_printf("Expected is {1, %lld}\n", exp4.lengthOf()); + + //exp.printBuffer("EXPect symm3"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + 
ASSERT_TRUE(exp4.equalsTo(res)); + +} + +TEST_F(DeclarableOpsTests13, CellContains_test_1) { + + auto corners = NDArrayFactory::create( {0.5384, 0.5640, 0.3449, 0.5257, 0.5505}); + auto width = NDArrayFactory::create({0.4306, 0.3960, 0.4639, 0.5040, 0.4904}); + auto point = NDArrayFactory::create({0.3000, 0.2625, 0.2674, 0.8604, 0.4803}); + //auto exp = NDArrayFactory::create('c', {1, 39}, {15.000000, 0.000000, 0.000000, 65.000000, 60.000000, 145.000000, 20.000000, 25.000000, 65.000000, 145.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000}); + // data.linspace(1); + + // auto y = NDArrayFactory::create('c', {2,3}, {-0.1,-2,3, -4, -0.5, -6}); + // auto eps = NDArrayFactory::create('c', {2,3}, {-0.1, 0.2, -0.3, 0.4, -0.5, 0.6}); + // auto exp = NDArrayFactory::create('c', {2,3}, {1, 2, 1, 2, 2, 2}); + sd::ops::cell_contains op; + auto result = op.evaluate({&corners, &width, &point}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(result.at(0)->e(0)); + //result.at(2)->printBuffer("Symmetrized3"); + //exp.printBuffer("EXPect symm3"); + // ASSERT_TRUE(exp[i]->equalsTo(result.at(i))); + //ASSERT_TRUE(exp.equalsTo(result.at(0))); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, adjustHue_1) { + + NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); + NDArray factor = NDArrayFactory::create(0.5); + NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, sd::DataType::FLOAT32); + + + sd::ops::adjust_hue op; + auto results (op.evaluate({&input, &factor}, {}, {2})); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer(); + + 
ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, adjustHue_2) { + + NDArray input('c', { 2,2,3 }, { 0.f,100.f / 255.f,56.f / 255.f, 17.f / 255.f,220.f / 255.f,5.f / 255.f, 150.f / 255.f,97.f / 255.f,230.f / 255.f, 255.f / 255.f,2.f / 255.f,13.f / 255.f }, sd::DataType::FLOAT32); + NDArray exp('c', { 2,2,3 }, { 4.f / 255.f,100.f / 255.f,0.f, 146.f / 255.f,220.f / 255.f,5.f / 255.f, 97.f / 255.f,123.8f / 255.f,230.f / 255.f, 255.f / 255.f,2.f / 255.f,164.8f / 255.f }, sd::DataType::FLOAT32); + + + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {0.9}, {2})); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, adjustHue_3) { + + NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); + NDArray exp ('c', {2,2,3}, {0.,84.,100., 5.,220.,122.0001, 229.8,97.,230., 255.,142.8002,2.}, sd::DataType::FLOAT32); + + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {-0.9}, {2})); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, adjustHue_4) { + + NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, sd::DataType::FLOAT32); + NDArray exp ('c', {2,3,2}, {100,208, 0,5, 44,220, 177,2, 230,255, 97,244}, sd::DataType::FLOAT32); + + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {0.5}, {1})); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(exp.isSameShape(result)); + 
ASSERT_TRUE(exp.equalsTo(result)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, adjustHue_5) { + + NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, sd::DataType::FLOAT32); + NDArray exp ('c', {3,2,2}, {100,208, 177,2, 0,5, 230,255, 44,220, 97,244}, sd::DataType::FLOAT32); + + sd::ops::adjust_hue op; + auto results(op.evaluate({&input}, {0.5}, {0})); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, adjustSaturation_1) { + + NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); + NDArray factor = NDArrayFactory::create(0.5); + NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, sd::DataType::FLOAT32); + + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input, &factor}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, adjustSaturation_2) { + + NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::DOUBLE); + NDArray exp ('c', {2,2,3}, {0.,100.,56., 12.279087,220.,0., 91.654228,0.,230., 255.,0.,11.087015}, sd::DataType::DOUBLE); + + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {10}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); +// result.printIndexedBuffer("Result2"); +// exp.printIndexedBuffer("Expect2"); + + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests13, adjustSaturation_3) { + + NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, sd::DataType::FLOAT32); + NDArray exp ('c', {2,2,3}, {100.,100.,100., 220.,220.,220., 230.,230.,230., 255., 255., 255.}, sd::DataType::FLOAT32); + + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {-10}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); + +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, adjustSaturation_4) { + + NDArray input('c', {2,3,2}, {0,17, 100,220, 56,5, 150,255, 97,2, 230,13}, sd::DataType::FLOAT32); + NDArray exp ('c', {2,3,2}, {50,118.5, 100,220, 78,112.5, 190,255, 163.5,128.5, 230,134}, sd::DataType::FLOAT32); + + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {0.5}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + // result.printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, adjustSaturation_5) { + + NDArray input('c', {3,2,2}, {0,17, 150,255, 100,220, 97,2, 56,5, 230,13}, sd::DataType::FLOAT32); + NDArray exp ('c', {3,2,2}, {50,118.5, 190,255, 100,220, 163.5,128.5, 78,112.5, 230,134}, sd::DataType::FLOAT32); + + sd::ops::adjust_saturation op; + auto results = op.evaluate({&input}, {0.5}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto result = results.at(0); + + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); + +} + + +TEST_F(DeclarableOpsTests13, shift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(32); + e.assign(512); + + sd::ops::shift_bits op; + auto result = op.evaluate({&x, 
&y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + +} + +TEST_F(DeclarableOpsTests13, rshift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(512); + e.assign(32); + + sd::ops::rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + +} + +TEST_F(DeclarableOpsTests13, cyclic_shift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(32); + e.assign(512); + + sd::ops::cyclic_shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + +} + +TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_1) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create(4); + auto e = x.ulike(); + x.assign(512); + e.assign(32); + + sd::ops::cyclic_rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + +} + +TEST_F(DeclarableOpsTests13, shift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); + + sd::ops::shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + +} + +TEST_F(DeclarableOpsTests13, rshift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + y.assign(4); + e.assign(32); + + sd::ops::rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + + +} + 
+TEST_F(DeclarableOpsTests13, cyclic_shift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); + + sd::ops::cyclic_shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + +} + +TEST_F(DeclarableOpsTests13, cyclic_rshift_bits_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + auto e = x.ulike(); + x.assign(512); + y.assign(4); + e.assign(32); + + sd::ops::cyclic_rshift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + +} +TEST_F(DeclarableOpsTests13, shift_bits_3) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {1, 5}); + auto e = x.ulike(); + x.assign(32); + y.assign(4); + e.assign(512); + + sd::ops::shift_bits op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, space_to_batch_nd_1) { + + NDArray x('c', {1, 2, 2, 2, 3}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 2} , sd::DataType::INT32); // three spatial dimensions + NDArray paddings('c', {3, 2}, std::vector{0, 0, 0, 0, 0, 0} , sd::DataType::INT32); + + NDArray exp('c', {8, 1, 1, 1, 3}, sd::DataType::FLOAT32); + + x.linspace(1); + exp.linspace(1); + + sd::ops::space_to_batch_nd op; + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, 
space_to_batch_nd_2) { + + NDArray x('c', {2, 2,4,3, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions + NDArray paddings('c', {3, 2}, {0,0, 0,2, 2,1} , sd::DataType::INT32); + + NDArray exp('c', {24, 1,3,2, 1}, { 0, 2, 0, 8, 0, 0, 0, 26, 0, 32, 0, 0, 0, 3, 0, 9, 0, 0, 0, 27, 0, 33, 0, 0, 1, + 0, 7, 0, 0, 0, 25, 0, 31, 0, 0, 0, 0, 5, 0, 11, 0, 0, 0, 29, 0, 35, 0, 0, 0, 6, + 0, 12, 0, 0, 0, 30, 0, 36, 0, 0, 4, 0, 10, 0, 0, 0, 28, 0, 34, 0, 0, 0, 0, 14, + 0, 20, 0, 0, 0, 38, 0, 44, 0, 0, 0, 15, 0, 21, 0, 0, 0, 39, 0, 45, 0, 0, 13, 0, + 19, 0, 0, 0, 37, 0, 43, 0, 0, 0, 0, 17, 0, 23, 0, 0, 0, 41, 0, 47, 0, 0, 0, 18, + 0, 24, 0, 0, 0, 42, 0, 48, 0, 0, 16, 0, 22, 0, 0, 0, 40, 0, 46, 0, 0, 0}, sd::DataType::FLOAT32); + x.linspace(1); + + sd::ops::space_to_batch_nd op; + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, space_to_batch_nd_3) { + + NDArray x('c', {2, 2,4,3, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions + NDArray paddings('c', {3, 2}, {1,1, 0,2, 2,1} , sd::DataType::INT32); + + NDArray exp('c', {24, 2,3,2, 1}, { 0, 0, 0, 0, 0, 0, 0, 14, 0, 20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 38, 0, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0, 15, + 0, 21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 39, 0, 45, 0, 0, 0, 0, 0, 0, 0, 0, 13, 0, 19, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 37, 0, 43, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 17, 0, 23, 0, 0, 0, 0, 0, 0, 0, 0, 0, 41, 0, 47, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 18, 0, 24, 0, 0, 0, 0, 0, 0, 0, 0, 0, 42, 0, 48, 0, 0, 0, 0, 0, 0, 0, 0, 16, 0, + 22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 40, 0, 46, 0, 0, 0, 0, 2, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 26, 0, 32, + 0, 0, 0, 0, 0, 0, 
0, 0, 0, 3, 0, 9, 0, 0, 0, 0, 0, 0, 0, 0, 0, 27, 0, 33, 0, 0, 0, 0, 0, 0, 0, 0, 1, + 0, 7, 0, 0, 0, 0, 0, 0, 0, 0, 0, 25, 0, 31, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 11, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 29, 0, 35, 0, 0, 0, 0, 0, 0, 0, 0, 0, 6, 0, 12, 0, 0, 0, 0, 0, 0, 0, 0, 0, 30, 0, 36, 0, 0, + 0, 0, 0, 0, 0, 0, 4, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 28, 0, 34, 0, 0, 0, 0, 0, 0, 0, 0, 0}, sd::DataType::FLOAT32); + x.linspace(1); + + sd::ops::space_to_batch_nd op; + auto result = op.evaluate({&x, &blockShape, &paddings}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batch_to_space_nd_1) { + + NDArray x('c', {8, 1, 1, 1, 3}, sd::DataType::FLOAT32); + + NDArray blockShape('c', {3}, {2., 2, 2} , sd::DataType::INT32); // three spatial dimensions + NDArray crop('c', {3, 2}, {0., 0, 0, 0, 0, 0} , sd::DataType::INT32); + + NDArray exp('c', {1, 2, 2, 2, 3}, sd::DataType::FLOAT32); + + x.linspace(1); + exp.linspace(1); + + sd::ops::batch_to_space_nd op; + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batch_to_space_nd_2) { + + NDArray x('c', {24, 1,3,2, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions + NDArray crop('c', {3, 2}, {0,0, 0,2, 2,1} , sd::DataType::INT32); + + NDArray exp('c', {2, 2,4,3, 1}, {25, 2, 14, 61, 38, 50, 27, 4, 16, 63, 40, 52, 97, 74, 86, 133, 110, 122, 99, 76, 88, 135, 112, 124, + 31, 8, 20, 67, 44, 56, 33, 10, 22, 69, 46, 58, 103, 80, 92, 139, 116, 128, 105, 82, 94, 141, 118, 130}, 
sd::DataType::FLOAT32); + x.linspace(1); + + sd::ops::batch_to_space_nd op; + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batch_to_space_nd_3) { + + NDArray x('c', {24, 2,3,2, 1}, sd::DataType::FLOAT32); + NDArray blockShape('c', {3}, {2, 2, 3} , sd::DataType::INT32); // three spatial dimensions + NDArray crop('c', {3, 2}, {1,1, 0,2, 2,1} , sd::DataType::INT32); + + NDArray exp('c', {2, 2,4,3, 1}, {193, 146, 170, 265, 218, 242, 195, 148, 172, 267, 220, 244, 55, 8, 32, 127, 80, 104, 57, 10, 34, 129, 82, + 106, 205, 158, 182, 277, 230, 254, 207, 160, 184, 279, 232, 256, 67, 20, 44, 139, 92, 116, 69, 22, 46, 141, 94, 118}, sd::DataType::FLOAT32); + x.linspace(1); + + sd::ops::batch_to_space_nd op; + auto result = op.evaluate({&x, &blockShape, &crop}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, mergemax_1) { + + NDArray x1('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x2('c', {5, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {5, 5}, sd::DataType::FLOAT32); + NDArray e('c', {5, 5}, sd::DataType::FLOAT32); + x1.assign(3); + x2.assign(1); + x3.assign(2); + e.assign(3); + + + sd::ops::mergemax op; + auto result = op.evaluate({&x1, &x2, &x3}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, mergemax_2) { + + NDArray x1('c', {1, 3}, 
{0., 1, 2}, sd::DataType::FLOAT32); + NDArray x2('c', {1, 1}, std::vector{1.}, sd::DataType::FLOAT32); + NDArray out('c', {1, 3}, {-1., -1, -1}, sd::DataType::FLOAT32); + + sd::ops::mergemax op; + auto status = op.execute({&x1, &x2}, {&out}, {}, {}, {}); + + ASSERT_EQ(20, status); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, mergemax_bp_1) { + + NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32); + NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32); + NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32); + NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32); + + x1.assign(3); + x2.assign(1); + x3.assign(2); + grad.linspace(.1, .1); + + + sd::ops::mergemax_bp op; + auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + auto z = result.at(0); + + ASSERT_TRUE(grad.isSameShape(z)); + ASSERT_TRUE(grad.equalsTo(z)); + +} +///////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, mergemax_bp_2) { + + NDArray x1('c', { 2, 5 }, { 1,2,3,4,5,4,3,2,1,0 }, sd::DataType::FLOAT32); + NDArray x2('c', { 2, 5 }, { 0,1,2,3,4,5,6,7,8,9 }, sd::DataType::FLOAT32); + NDArray x3('c', { 2, 5 }, { 0,1,1,2,3,4,7,5,8,10 }, sd::DataType::FLOAT32); + NDArray grad('c', { 2, 5 }, sd::DataType::FLOAT32); + + grad.linspace(.1, .1); + + NDArray exp1('c', { 2, 5 }, { 0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0 }, sd::DataType::FLOAT32); + NDArray exp2('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0 }, sd::DataType::FLOAT32); + NDArray exp3('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0 }, sd::DataType::FLOAT32); + + sd::ops::mergemax_bp op; + auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + auto z1 = result.at(0); + auto z2 = result.at(1); + auto z3 = result.at(2); + + 
ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp2.isSameShape(z2)); + ASSERT_TRUE(exp2.equalsTo(z2)); + ASSERT_TRUE(exp3.isSameShape(z3)); + ASSERT_TRUE(exp3.equalsTo(z3)); + +} +///////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, mergemax_bp_3) { + + NDArray x1C('c', { 2, 5 }, { 1,2,3,4,5,4,3,2,1,0 }, sd::DataType::FLOAT32); + NDArray x2C('c', { 2, 5 }, { 0,1,2,3,4,5,6,7,8,9 }, sd::DataType::FLOAT32); + NDArray x3C('c', { 2, 5 }, { 0,1,1,2,3,4,7,5,8,10 }, sd::DataType::FLOAT32); + NDArray grad('c', { 2, 5 }, sd::DataType::FLOAT32); + + grad.linspace(.1, .1); + + NDArray x1('f', { 2, 5 }, sd::DataType::FLOAT32); + NDArray x2('f', { 2, 5 }, sd::DataType::FLOAT32); + NDArray x3('f', { 2, 5 }, sd::DataType::FLOAT32); + + NDArray exp1C('c', { 2, 5 }, { 0.1, 0.2, 0.3, 0.4, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0 }, sd::DataType::FLOAT32); + NDArray exp2C('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.6, 0.0, 0.8, 0.9, 0.0 }, sd::DataType::FLOAT32); + NDArray exp3C('c', { 2, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7, 0.0, 0.0, 1.0 }, sd::DataType::FLOAT32); + + NDArray exp1('f', { 2, 5 }, sd::DataType::FLOAT32); + NDArray exp2('f', { 2, 5 }, sd::DataType::FLOAT32); + NDArray exp3('f', { 2, 5 }, sd::DataType::FLOAT32); + + x1.assign(x1C); + x2.assign(x2C); + x3.assign(x3C); + + exp1.assign(exp1C); + exp2.assign(exp2C); + exp3.assign(exp3C); + + sd::ops::mergemax_bp op; + auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + auto z1 = result.at(0); + auto z2 = result.at(1); + auto z3 = result.at(2); + + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp2.isSameShape(z2)); + ASSERT_TRUE(exp2.equalsTo(z2)); + ASSERT_TRUE(exp3.isSameShape(z3)); + ASSERT_TRUE(exp3.equalsTo(z3)); + +} +///////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests13, mergeadd_bp_1) { + + NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32); + NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32); + NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32); + NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32); + + x1.assign(3); + x2.assign(1); + x3.assign(2); + grad.linspace(.1, .1); + + sd::ops::mergeadd_bp op; + auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + for (int i = 0; i < 3; i++) { + auto z = result.at(0); + ASSERT_TRUE(grad.isSameShape(z)); + ASSERT_TRUE(grad.equalsTo(z)); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, mergeavg_bp_1) { + + NDArray x1('c', { 5, 5 }, sd::DataType::FLOAT32); + NDArray x2('c', { 5, 5 }, sd::DataType::FLOAT32); + NDArray x3('c', { 5, 5 }, sd::DataType::FLOAT32); + NDArray grad('c', { 5, 5 }, sd::DataType::FLOAT32); + + x1.assign(3); + x2.assign(1); + x3.assign(2); + grad.linspace(.1, .1); + + sd::ops::mergeavg_bp op; + auto result = op.evaluate({ &x1, &x2, &x3, &grad }, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + + grad.applyScalar(sd::scalar::Divide, 3, grad); + + for (int i = 0; i < 3; i++) { + auto z = result.at(i); + ASSERT_TRUE(grad.isSameShape(z)); + ASSERT_TRUE(grad.equalsTo(z)); + } + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_1) { + + const int sL = 5; + const int bS = 3; + const int nIn = 3; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool 
hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + auto expH = NDArrayFactory::create('c', {sL, bS, nOut}, {0.57574f, 0.57574f, 0.57574f, 0.58006f, 0.58006f, 0.58006f, 0.58434f, 0.58434f, 0.58434f, + 0.55114f, 0.55114f, 0.55114f, 0.55732f, 0.55732f, 0.55732f, 0.56338f, 0.56338f, 0.56338f, + 0.53763f, 0.53763f, 0.53763f, 0.54534f, 0.54534f, 0.54534f, 0.55287f, 0.55287f, 0.55287f, + 0.53626f, 0.53626f, 0.53626f, 0.54487f, 0.54487f, 0.54487f, 0.55327f, 0.55327f, 0.55327f, + 0.54484f, 0.54484f, 0.54484f, 0.55379f, 0.55379f, 0.55379f, 0.5625f, 0.5625f, 0.5625f}); + + auto expClast = NDArrayFactory::create('c', {bS, nOut}, {1.1589154f, 1.1589154f, 1.1589154f, 1.1892855f, 1.1892855f, 1.1892855f, 1.219861f, 1.219861f, 1.219861f}); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, 
results.status()); + + auto *h = results.at(0); + auto *cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(cL)); + ASSERT_TRUE(expClast.equalsTo(cL)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_2) { + + const int sL = 5; + const int bS = 3; + const int nIn = 3; + const int nOut = 3; + + // input arguments + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + auto expH = NDArrayFactory::create('c', {bS, sL, nOut}, {0.575735f, 0.575735f, 0.575735f, 0.541562f, 0.541562f, 0.541562f, 0.514003f, 0.514003f, 0.514003f, 0.495597f, 0.495597f, 0.495597f, 0.485999f, 0.485999f, 0.485999f, + 0.596965f, 0.596965f, 0.596965f, 0.571978f, 0.571978f, 0.571978f, 0.552888f, 0.552888f, 0.552888f, 0.540606f, 0.540606f, 0.540606f, 0.534764f, 0.534764f, 0.534764f, + 0.61725f, 0.61725f, 0.61725f, 0.599828f, 0.599828f, 0.599828f, 0.587627f, 0.587627f, 0.587627f, 0.580408f, 0.580408f, 0.580408f, 0.577735f, 0.577735f, 0.577735f}); + + auto expClast = NDArrayFactory::create('c', {bS, nOut}, {0.996965f, 0.996965f, 0.996965f, 1.146756f, 1.146756f, 1.146756f, 1.301922f, 1.301922f, 1.301922f}); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *h = results.at(0); + auto *cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(cL)); + ASSERT_TRUE(expClast.equalsTo(cL)); + +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_3) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL,bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f, 0.534701f, 0.534701f, 0.534701f, 0.549139f, + 0.549139f, 0.549139f, 0.571900f, 0.571900f, 0.571900f, 0.583561f, 0.583561f, 0.583561f, 
0.605106f, 0.605106f, + 0.605106f, 0.614114f, 0.614114f, 0.614114f, 0.635354f, 0.635354f, 0.635354f, 0.642045f, 0.642045f, 0.642045f}, sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0.493883f, 0.493883f, 0.493883f, 0.510990f, 0.510990f, 0.510990f}, sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {1.061274f, 1.061274f, 1.061274f, 1.115888f, 1.115888f, 1.115888f}, sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_4) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003f; + Wx({1,2, 0,0, 0,0}) = -0.003f; + Wr({0,1, 0,0, 0,0}) = 0.006f; + Wr({1,2, 0,0, 0,0}) = -0.006f; + b({0,1, 0,0}) = 0.5f; + b({1,2, 0,0}) = -0.5f; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2 * nOut}, { + 0.577661f, 0.577661f, 0.577661f, -0.107642f, -0.107642f, -0.107642f, 0.585289f, 0.585289f, 0.585289f, + -0.106937f, -0.106937f, -0.106937f, 0.556517f, 0.556517f, 0.556517f, -0.111647f, -0.111647f, -0.111647f, + 0.567274f, 0.567274f, 0.567274f, -0.110214f, -0.110214f, -0.110214f, 0.547395f, 0.547395f, 0.547395f, + -0.123305f, -0.123305f, -0.123305f, 0.560640f, 0.560640f, 0.560640f, -0.120862f, -0.120862f, -0.120862f, + 0.550714f, 0.550714f, 0.550714f, -0.156223f, -0.156223f, -0.156223f, 0.565308f, 0.565308f, 0.565308f, + -0.152313f, -0.152313f, -0.152313f, 0.563741f, 0.563741f, 0.563741f, -0.234128f, -0.234128f, -0.234128f, + 0.578676f, 0.578676f, 0.578676f, -0.228917f, -0.228917f, -0.228917f}, sd::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, -0.107642f, + -0.107642f, 
-0.107642f, -0.106937f, -0.106937f, -0.106937f}, sd::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, -0.295768f, + -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_5) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {bS, sL, 2*nOut}, { + 0.577661f, 0.577661f, 0.577661f, -0.107659f, -0.107659f, -0.107659f, 0.548099f, 0.548099f, 0.548099f, -0.113406f, -0.113406f, -0.113406f, + 0.526881f, 0.526881f, 0.526881f, -0.12883f, -0.12883f, -0.12883f, 0.515882f, 0.515882f, 0.515882f, -0.16868f, -0.16868f, -0.16868f, + 0.51409f, 0.51409f, 0.51409f, -0.255185f, -0.255185f, -0.255185f, 0.614599f, 0.614599f, 0.614599f, -0.102739f, -0.102739f, -0.102739f, + 0.599572f, 0.599572f, 0.599572f, -0.105802f, -0.105802f, -0.105802f, 0.591089f, 0.591089f, 0.591089f, -0.116681f, -0.116681f, -0.116681f, + 0.588694f, 0.588694f, 0.588694f, -0.149201f, -0.149201f, -0.149201f, 0.591492f, 0.591492f, 0.591492f, -0.228917f, -0.228917f, -0.228917f}, sd::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.51409f, 0.51409f, 0.51409f, 0.591492f, 0.591492f, 0.591492f, + -0.107659f, -0.107659f, -0.107659f, -0.102739f, 
-0.102739f, -0.102739f}, sd::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.07293f , 1.07293f , 1.07293f, 1.346609f, 1.346609f, 1.346609f, + -0.295811f, -0.295811f, -0.295811f, -0.305394f, -0.305394f, -0.305394f}, sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + // h->printBuffer(); + // hL->printBuffer(); + // cL->printBuffer(); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_6) { + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 2; // bidirectional sum + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = false; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003f; + Wx({1,2, 0,0, 0,0}) = -0.003f; + Wr({0,1, 0,0, 0,0}) = 0.006f; + Wr({1,2, 0,0, 0,0}) = -0.006f; + b({0,1, 0,0}) = 0.5f; + b({1,2, 0,0}) = -0.5f; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, { + 0.470019f, 0.470019f, 0.470019f, 0.478352f, 0.478352f, 0.478352f, 0.444871f, 0.444871f, 0.444871f, 0.457060f, + 0.457060f, 0.457060f, 0.424090f, 0.424090f, 0.424090f, 0.439778f, 0.439778f, 0.439778f, 0.394491f, 0.394491f, + 0.394491f, 0.412995f, 0.412995f, 0.412995f, 0.329613f, 0.329613f, 0.329613f, 0.349760f, 0.349760f, 0.349760f}, sd::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.563741f, 0.563741f, 0.563741f, 0.578676f, 0.578676f, 0.578676f, + -0.107642f, -0.107642f, -0.107642f, -0.106937f, -0.106937f, -0.106937f}, + sd::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.217757f, 1.217757f, 1.217757f, 1.272398f, 1.272398f, 1.272398f, + -0.295768f, -0.295768f, -0.295768f, -0.298453f, -0.298453f, -0.298453f}, + sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI}, tArgs, 
iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_7) { + #ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are provided + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, {0.55533 , 0.55533 , 0.55533 , 0.562925, 0.562925, 0.562925, 0.531795, 0.531795, 0.531795, 0.542556, + 0.542556, 0.542556, 0.521466, 0.521466, 0.521466, 0.534638, 0.534638, 0.534638, 0.524805, 0.524805, + 0.524805, 0.539187, 0.539187, 0.539187, 0.538309, 0.538309, 0.538309, 0.552923, 0.552923, 0.552923}, sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0.538309, 0.538309, 0.538309,0.552923, 0.552923, 0.552923}, sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {1.147089, 1.147089, 1.147089,1.197228, 1.197228, 1.197228}, sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + 
ASSERT_TRUE(expCL.equalsTo(cL)); + + + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_8) { + #ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are provided + const auto retFullSeq = true; // return whole h {h_0, h_1, ... , h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 1.; // apply clipping with threshold 1 + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, { + 0.436221f, 0.436221f, 0.436221f, 
0.450573f, 0.450573f, 0.450573f, 0.463602f, 0.463602f, 0.463602f, 0.474674f, 0.474674f, 0.474674f, + 0.484039f, 0.484039f, 0.484039f, 0.490679f, 0.490679f, 0.490679f, 0.494871f, 0.494871f, 0.494871f, 0.499028f, 0.499028f, 0.499028f, + 0.504649f, 0.504649f, 0.504649f, 0.508719f, 0.508719f, 0.508719f}, sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0.436221f, 0.436221f, 0.436221f, 0.450573f, 0.450573f, 0.450573f}, sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0.879804f, 0.879804f, 0.879804f, 0.914666f, 0.914666f, 0.914666f}, sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_9) { + #ifndef HAVE_MKLDNN + + const int sL = 5; + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {2,3*nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003; + Wx({1,2, 0,0, 0,0}) = -0.003; + Wr({0,1, 0,0, 0,0}) = 0.006; + Wr({1,2, 0,0, 0,0}) = -0.006; + b({0,1, 0,0}) = 0.5; + b({1,2, 0,0}) = -0.5; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + Wp({0,1, 0,0}) = -0.05; + Wp({1,2, 0,0}) = 0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2*nOut}, { + 0.55533f, 0.55533f, 0.55533f, -0.104502f, -0.104502f, -0.104502f, 0.562925f, 0.562925f, 0.562925f, -0.103843f, -0.103843f, -0.103843f, + 0.531795f, 0.531795f, 0.531795f, -0.107456f, -0.107456f, -0.107456f, 0.542556f, 0.542556f, 0.542556f, -0.106139f, -0.106139f, -0.106139f, + 0.521466f, 0.521466f, 0.521466f, -0.11681f, -0.11681f, -0.11681f, 0.534638f, 0.534638f, 0.534638f, -0.11458f, -0.11458f, -0.11458f, + 0.524805f, 0.524805f, 0.524805f, -0.145177f, -0.145177f, -0.145177f, 0.539187f, 0.539187f, 0.539187f, -0.14157f, -0.14157f, -0.14157f, + 0.538309f, 0.538309f, 0.538309f, -0.218056f, -0.218056f, -0.218056f, 0.552923f, 0.552923f, 0.552923f, -0.213068f, -0.213068f, -0.213068f}, sd::DataType::FLOAT32); + + NDArray 
expHL('c', {2,bS, nOut}, {0.538309f, 0.538309f, 0.538309f, 0.552923f, 0.552923f, 0.552923f, -0.104502f, -0.104502f, -0.104502f, + -0.103843f, -0.103843f, -0.103843f}, sd::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {1.147089f, 1.147089f, 1.147089f, 1.197228f, 1.197228f, 1.197228f, -0.289425f, -0.289425f, -0.289425f, + -0.292174f, -0.292174f, -0.292174f}, sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_10) { + #ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are provided + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0,1,2,3,5}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003; + Wr = 0.006; + b = 0.5; + hI = 1.; + cI = 2.; + Wp = -0.05; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, { + 0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.570404f, 0.570404f, 0.570404f, 0.57777f, + 0.57777f, 0.57777f, 0.585023f, 0.585023f, 0.585023f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.576568f, 0.576568f, 0.576568f, 0.586163f, 0.586163f, 0.586163f, 0.595462f, 0.595462f, 0.595462f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.611224f, + 0.611224f, 0.611224f, 0.621298f, 0.621298f, 0.621298f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.655858f, 0.655858f, 0.655858f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.692315f, 0.692315f, 0.692315f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}, + sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f}, sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0.f, 
0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f}, sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_11) { + #ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are provided + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0,1,2,3,5}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx = 0.003f; + Wr = 0.006f; + b = 0.5f; + hI = 1.f; + cI = 2.f; + Wp = -0.05f; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, nOut}, { + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.61209f, + 0.61209f, 0.61209f,0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.652042f, 0.652042f, 0.652042f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.677708f, 0.677708f, 0.677708f, 0.684177f, 0.684177f, 0.684177f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.699627f, 0.699627f, + 0.699627f, 0.705371f, 0.705371f, 0.705371f, 0.710989f, 0.710989f, 0.710989f, 0., 0., 0., 0.719014, 0.719014, 0.719014, 0.724087, + 0.724087f, 0.724087f, 0.729084f, 0.729084f, 0.729084f, 0.734004f, 0.734004f, 0.734004f }, sd::DataType::FLOAT32); + + NDArray expHL('c', {bS, nOut}, {0.f, 0.f, 0.f, 0.719014f, 0.719014f, 0.719014f, 0.699627f, 0.699627f, 0.699627f, 0.677708f, 0.677708f, 0.677708f, 0.61209f, 0.61209f, 0.61209f}, sd::DataType::FLOAT32); + NDArray expCL('c', {bS, nOut}, {0.f, 0.f, 0.f, 
2.092814f, 2.092814f, 2.092814f, 2.08832f, 2.08832f, 2.08832f, 2.009851f, 2.009851f, 2.009851f, 1.646034f, 1.646034f, 1.646034f}, sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + + #endif +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_12) { + #ifndef HAVE_MKLDNN + + const int sL = 6; + const int bS = 5; + const int nIn = 4; + const int nOut = 3; + + // input arguments + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are provided + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = true; // do not return output at last time step + const auto retLastC = true; // return cells state at last time step + + const double cellClip = 0; // do not apply clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {2,nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {2,nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {2,4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {2,bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {2,bS, nOut}, sd::DataType::FLOAT32); + NDArray seqLen('c', {bS}, {0,1,2,3,5}, sd::DataType::FLOAT32); + NDArray Wp('c', {2,3*nOut}, sd::DataType::FLOAT32); + + x.linspace(0.5, 0.5); + Wx({0,1, 0,0, 0,0}) = 0.003f; + Wx({1,2, 0,0, 0,0}) = -0.003f; + Wr({0,1, 0,0, 0,0}) = 0.006f; + Wr({1,2, 0,0, 0,0}) = -0.006f; + b({0,1, 0,0}) = 0.5f; + b({1,2, 0,0}) = -0.5f; + hI({0,1, 0,0, 0,0}) = 1; + hI({1,2, 0,0, 0,0}) = -1; + cI({0,1, 0,0, 0,0}) = 2; + cI({1,2, 0,0, 0,0}) = -2; + Wp({0,1, 0,0}) = -0.05f; + Wp({1,2, 0,0}) = 0.05f; + + std::initializer_list tArgs = {cellClip}; + std::initializer_list iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::initializer_list bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + NDArray expH('c', {sL, bS, 2*nOut}, {0., 0., 0., 0., 0., 0., 0.562925, 0.562925, 0.562925, -0.25361 , -0.25361 , -0.25361 , 0.570404, 0.570404, 0.570404, -0.157103, + -0.157103, -0.157103, 0.57777 , 0.57777 , 0.57777 , -0.116502, -0.116502, -0.116502,0.585023, 0.585023, 0.585023, -0.100025, + -0.100025, -0.100025, 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., 0.576568, 0.576568, 0.576568, -0.223072, -0.223072, -0.223072, + 0.586163, 0.586163, 0.586163, -0.135714, -0.135714, -0.135714,0.595462, 0.595462, 0.595462, -0.094438, -0.094438, -0.094438, + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.611224, 0.611224, 0.611224, -0.193473, -0.193473, -0.193473, + 
0.621298, 0.621298, 0.621298, -0.090626, -0.090626, -0.090626, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0.655858, 0.655858, 0.655858, -0.098015, -0.098015, -0.098015, 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.692315, 0.692315, 0.692315, -0.143704, -0.143704, -0.143704, 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); + + NDArray expHL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 0.562925f, 0.562925f, 0.562925f, 0.576568f, 0.576568f, 0.576568f, 0.611224f, 0.611224f, 0.611224f, 0.692315f, 0.692315f, 0.692315f, + 0.f, 0.f, 0.f, -0.25361f, -0.25361f, -0.25361f, -0.157103f, -0.157103f, -0.157103f, -0.116502f, -0.116502f, -0.116502f, -0.100025f, -0.100025f, -0.100025f}, sd::DataType::FLOAT32); + NDArray expCL('c', {2,bS, nOut}, {0.f, 0.f, 0.f, 1.534275f, 1.534275f, 1.534275f, 1.40183f, 1.40183f, 1.40183f, 1.449675f, 1.449675f, 1.449675f, 1.767702f, 1.767702f, 1.767702f, + 0.f, 0.f, 0.f, -0.86636f, -0.86636f, -0.86636f, -0.470245f, -0.470245f, -0.470245f, -0.341856f, -0.341856f, -0.341856f, -0.294986f, -0.294986f, -0.294986f}, sd::DataType::FLOAT32); + + sd::ops::lstmLayer op; + auto results = op.evaluate({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hL = results.at(1); + auto cL = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expHL.isSameShape(hL)); + ASSERT_TRUE(expHL.equalsTo(hL)); + + ASSERT_TRUE(expCL.isSameShape(cL)); + ASSERT_TRUE(expCL.equalsTo(cL)); + + #endif +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test1) { + + NDArray input ('c', {2,4}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, 
sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); + + NDArray expected('c', {2,4}, {11.61218734f, 18.52390321f, -8.67185076f, -21.28716864f, 10.93337162f, 19.14541765f, -9.26213931f, -20.71509369f}, sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test2) { + + auto input = NDArrayFactory::create('c', {2,3,4}); + auto mean = NDArrayFactory::create('c', {4}); + auto variance = NDArrayFactory::create('c', {4}); + auto gamma = NDArrayFactory::create('c', {4}); + auto beta = NDArrayFactory::create('c', {4}); + + auto expected = NDArrayFactory::create('c', {2,3,4}, {-0.52733537f, -0.35763144f, -0.18792751f, -0.01822358f, 0.15148035f, 0.32118428f, 0.49088821f, 0.66059214f, 0.83029607f, 1.f, 1.16970393f, 1.33940786f, + 1.50911179f, 1.67881572f, 1.84851965f, 2.01822358f, 2.18792751f, 2.35763144f, 2.52733537f, 2.6970393f, 2.86674323f, 3.03644717f, 3.2061511f, 3.37585503f}); + + input.linspace(0.1, 0.1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(1.); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + 
+//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test3) { + + auto input = NDArrayFactory::create('c', {2,3,4}); + auto mean = NDArrayFactory::create('c', {3}, {1.05f, 1.1f, 1.15f}); + auto variance = NDArrayFactory::create('c', {3}, {0.5f, 0.6f, 0.7f}); + auto gamma = NDArrayFactory::create('c', {3}, {1.2f, 1.3f, 1.4f}); + auto beta = NDArrayFactory::create('c', {3}, {0.1f, 0.2f, 0.3f}); + + auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.34248341f, -1.17277948f, -1.00307555f, -0.80696728f, -0.6391394f, -0.47131152f, -0.30348364f, -0.11832703f, 0.04900378f, 0.21633459f, 0.38366541f, + 0.52425983f, 0.69396376f, 0.86366769f, 1.03337162f, 1.20696728f, 1.37479516f, 1.54262304f, 1.71045092f, 1.8896427f, 2.05697351f, 2.22430432f, 2.39163513f}); + + input.linspace(0.1, 0.1); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests13, batchnorm_test4) { + + auto input = NDArrayFactory::create('c', {2,3,4}); + auto mean = NDArrayFactory::create('c', {2,1,4}, {1.05f, 1.1f, 1.15f, 1.2f, 1.25f, 1.3f, 1.35f, 1.4f}); + auto variance = NDArrayFactory::create('c', {2,1,4}, {0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); + auto gamma = NDArrayFactory::create('c', {2,1,4}, {1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f}); + auto beta = NDArrayFactory::create('c', {2,1,4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.66f, 0.7f, 0.8f}); + + auto expected = NDArrayFactory::create('c', {2,3,4}, {-1.51218734f, -1.31045092f, -1.12231189f, -0.9416324f, -0.83337162f, -0.6391394f, -0.45298865f, -0.2708162f, -0.1545559f, 0.03217212f, 0.21633459f, 0.4f, + 
0.58432694f, 0.82999915f, 0.95743373f, 1.14688951f, 1.25894242f, 1.50999575f, 1.64392367f, 1.84066852f, 1.93355791f, 2.18999235f, 2.33041362f, 2.53444754f}); + + input.linspace(0.1, 0.1); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,0,2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test5) { + + NDArray input ('c', {2,4,2,2}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9f, 1.1f}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); + + NDArray expected('c', {2,4,2,2}, { 11.612187f, 11.442483f, 11.272779f, 11.103076f, 18.990039f, 19.145418f, 19.300796f, 19.456175f, -9.557284f, -9.704856f, -9.852428f, -10.f, -20.f, + -19.856981f, -19.713963f, -19.570944f, 8.896924f, 8.727221f, 8.557517f, 8.387813f, 21.476097f, 21.631475f, 21.786854f, 21.942233f, -11.918438f, + -12.06601f, -12.213582f, -12.361154f, -17.7117f, -17.568681f, -17.425663f, -17.282644f}, sd::DataType::FLOAT32); + input.linspace(0.1, 0.1); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test6) { + + NDArray input ('c', {2,2,2,4}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, 
{1.05f, 1.15f, 1.2f, 1.3f}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5f, 0.7f, 0.9, 1.1f}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2f, 1.3f, -1.4f, 1.5f}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, {10.f, 20.f, -10.f, -20.f}, sd::DataType::FLOAT32); + + NDArray expected('c', {2,2,2,4}, {11.612187f, 18.523903f, -8.671851f, -21.287169f, 10.933372f, 19.145418f, -9.262139f, -20.715094f, 10.254556f, 19.766932f, -9.852428f, -20.143019f, 9.57574f, + 20.388447f, -10.442716f, -19.570944f, 8.896924f, 21.009961f, -11.033005f, -18.998869f, 8.218109f, 21.631475f, -11.623294f, -18.426794f, 7.539293f, 22.25299f, + -12.213582f, -17.854719f, 6.860477f, 22.874504f, -12.803871f, -17.282644f}, sd::DataType::FLOAT32); + input.linspace(0.1, 0.1); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test7) { + + NDArray input1('c', {3,3,15,15}, sd::DataType::FLOAT32); + NDArray input2('c', {3,15,15,3}, sd::DataType::FLOAT32); + input2.permutei({0,3,1,2}); + + NDArray mean ('c', {3}, {0., 0, 0}, sd::DataType::FLOAT32); + NDArray variance('c', {3}, {1., 1, 1}, sd::DataType::FLOAT32); + NDArray gamma ('c', {3}, {1., 1, 1}, sd::DataType::FLOAT32); + NDArray beta ('c', {3}, {0., 0, 0}, sd::DataType::FLOAT32); + + NDArray out1('c', {3,3,15,15}, sd::DataType::FLOAT32); + NDArray out2('c', {3,3,15,15}, sd::DataType::FLOAT32); + + input1.linspace(-1012, 1); + input2.assign(input1); + + sd::ops::batchnorm op; + + auto res1 = op.execute({&input1, &mean, &variance, &gamma, &beta}, {&out1}, {1e-5}, {1,1,1}, {}); + ASSERT_EQ(ND4J_STATUS_OK, res1); + + auto res2 = op.execute({&input2, &mean, &variance, &gamma, 
&beta}, {&out2}, {1e-5}, {1,1,1}, {}); + ASSERT_EQ(ND4J_STATUS_OK, res2); + + ASSERT_TRUE(out1.equalsTo(out2)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test8) { + + NDArray input('c', {2,3,4,5}, sd::DataType::FLOAT32); + + NDArray mean ('c', {1,3,4,5}, sd::DataType::FLOAT32); + NDArray variance('c', {1,3,4,5}, sd::DataType::FLOAT32); + NDArray gamma ('c', {1,3,4,5}, sd::DataType::FLOAT32); + NDArray beta ('c', {1,3,4,5}, sd::DataType::FLOAT32); + + NDArray expected('c', {2,3,4,5}, {-105.019394, -103.322357, -101.625313, -99.928276, -98.231239, -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, -88.049004, -86.351959, -84.654922, + -82.957886, -81.260841, -79.563805, -77.866768, -76.169724, -74.472687, -72.775650, -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, -62.593414, + -60.896374, -59.199333, -57.502296, -55.805256, -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, -45.623020, -43.925980, -42.228943, -40.531902, + -38.834862, -37.137825, -35.440784, -33.743744, -32.046707, -30.349667, -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, -20.167431, -18.470392, + -16.773354, -15.076314, -13.379274, -11.682236, -9.985196, -8.288157, -6.591118, -4.894078, -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, 5.288157, + 6.985196, 8.682236, 10.379274, 12.076314, 13.773354, 15.470392, 17.167431, 18.864471, 20.561510, 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, 30.743744, + 32.440784, 34.137825, 35.834862, 37.531902, 39.228943, 40.925980, 42.623020, 44.320061, 46.017097, 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, 56.199333, + 57.896374, 59.593414, 61.290451, 62.987488, 64.684532, 66.381569, 68.078606, 69.775650, 71.472687, 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, 81.654922, + 83.351959, 85.049004, 86.746040, 88.443077, 90.140121, 91.837158, 93.534195, 95.231239, 96.928276}, sd::DataType::FLOAT32); + + input.linspace(-60, 
1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(-1.5); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_test9) { + + NDArray input('c', {2,3,3,3,3}, sd::DataType::FLOAT32); + + NDArray mean ('c', {1,3,3,3,3}, sd::DataType::FLOAT32); + NDArray variance('c', {1,3,3,3,3}, sd::DataType::FLOAT32); + NDArray gamma ('c', {1,3,3,3,3}, sd::DataType::FLOAT32); + NDArray beta ('c', {1,3,3,3,3}, sd::DataType::FLOAT32); + + NDArray expected('c', {2,3,3,3,3}, {-138.960175, -137.263138, -135.566101, -133.869064, -132.172028, -130.474976, -128.777954, -127.080902, -125.383865, -123.686829, -121.989784, -120.292747, + -118.595711, -116.898666, -115.201630, -113.504593, -111.807549, -110.110512, -108.413475, -106.716431, -105.019394, -103.322357, -101.625313, -99.928276, + -98.231239, -96.534195, -94.837158, -93.140121, -91.443077, -89.746040, -88.049004, -86.351959, -84.654922, -82.957886, -81.260841, -79.563805, -77.866768, + -76.169724, -74.472687, -72.775650, -71.078606, -69.381569, -67.684532, -65.987488, -64.290451, -62.593414, -60.896374, -59.199333, -57.502296, -55.805256, + -54.108215, -52.411179, -50.714138, -49.017097, -47.320061, -45.623020, -43.925980, -42.228943, -40.531902, -38.834862, -37.137825, -35.440784, -33.743744, + -32.046707, -30.349667, -28.652628, -26.955589, -25.258549, -23.561510, -21.864471, -20.167431, -18.470392, -16.773354, -15.076314, -13.379274, -11.682236, + -9.985196, -8.288157, -6.591118, -4.894078, -3.197039, -1.500000, 0.197039, 1.894078, 3.591118, 5.288157, 6.985196, 8.682236, 10.379274, 12.076314, 13.773354, + 15.470392, 17.167431, 18.864471, 
20.561510, 22.258549, 23.955589, 25.652628, 27.349667, 29.046707, 30.743744, 32.440784, 34.137825, 35.834862, 37.531902, 39.228943, + 40.925980, 42.623020, 44.320061, 46.017097, 47.714138, 49.411179, 51.108215, 52.805256, 54.502296, 56.199333, 57.896374, 59.593414, 61.290451, 62.987488, 64.684532, + 66.381569, 68.078606, 69.775650, 71.472687, 73.169724, 74.866768, 76.563805, 78.260841, 79.957886, 81.654922, 83.351959, 85.049004, 86.746040, 88.443077, 90.140121, + 91.837158, 93.534195, 95.231239, 96.928276, 98.625313, 100.322357, 102.019394, 103.716431, 105.413475, 107.110512, 108.807549, 110.504593, 112.201630, 113.898666, + 115.595711, 117.292747, 118.989784, 120.686829, 122.383865, 124.080902, 125.777946, 127.474976, 129.172028, 130.869064, 132.566101, 134.263138}, sd::DataType::FLOAT32); + + input.linspace(-80, 1); + mean.assign(1.); + variance.assign(0.5); + gamma.assign(1.2); + beta.assign(-1.5); + + sd::ops::batchnorm op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta}, {1e-5}, {1,1, 1,2,3,4}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShape(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test1) { + + NDArray input ('c', {2,3,4}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.1, 1.2, 1.3, 1.4}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-0.000056, -0.000056, -0.000056, -0.000056, -0.000034, -0.000034, -0.000034, -0.000034, -0.000011, -0.000011, -0.000011, -0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000034, 0.000034, 0.000034, 0.000034, 0.000056, 0.000056, 0.000056, 
0.000056}, sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {6.148104, 6.148104, 6.148105, 6.148105}, sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {3.6, 4.5, 5.4, 6.3}, sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + variance.assign(0.46666667); + gamma.assign(1.2); + beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test2) { + + NDArray input ('c', {2,3,4}, sd::DataType::FLOAT32); + NDArray mean ('c', {3}, {1.05, 1.1, 1.15}, sd::DataType::FLOAT32); + NDArray variance('c', {3}, {0.5, 0.6, 0.7}, sd::DataType::FLOAT32); + NDArray gamma ('c', {3}, {1.2, 1.3, 1.4}, sd::DataType::FLOAT32); + NDArray beta ('c', {3}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-0.601415, -0.521226, -0.441037, -0.360849, -0.456306, -0.395465, -0.334624, -0.273784, 0.396631, 0.343747, + 0.290863, 0.237978, 0.360849, 0.441037, 0.521226, 0.601415, 0.273784, 0.334625, 0.395465, 0.456306, -0.237978, + -0.290863, -0.343746, -0.396631}, sd::DataType::FLOAT32); + NDArray expdLdG('c', {3}, {5.81236 , 7.048771, 12.155388}, sd::DataType::FLOAT32); + NDArray expdLdB('c', {3}, {1.8, 6.6, 11.4}, sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + // beta.assign(1.); // has no effect on gradient calculations + 
gradO.linspace(-0.9, 0.15); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test3) { + + NDArray input ('c', {2,3,4}, sd::DataType::FLOAT32); + NDArray mean ('c', {2,1,4}, {1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4}, sd::DataType::FLOAT32); + NDArray variance('c', {2,1,4}, {0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}, sd::DataType::FLOAT32); + NDArray gamma ('c', {2,1,4}, {1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9}, sd::DataType::FLOAT32); + NDArray beta ('c', {2,1,4}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4}, sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4}, {-0.577002, -0.744041, -0.850999, -0.922373, -0.000000, -0.000000, -0.000000, -0.000000, 0.577002, + 0.744041, 0.850999, 0.922373, -0.386037, -0.350205, -0.312047, -0.271737, -0.000000, -0.000000, + -0.000000, -0.000000, 0.386037, 0.350205, 0.312047, 0.271736}, sd::DataType::FLOAT32); + NDArray expdLdG('c', {2,1,4}, {1.378844, 0.910144, 0.573706, 0.335408, 2.640487, 2.954985, 3.289431, 3.64234 }, sd::DataType::FLOAT32); + NDArray expdLdB('c', {2,1,4}, {-0.9 , -0.45, 0. 
, 0.45, 4.5 , 4.95, 5.4 , 5.85}, sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + // beta.assign(1.); // has no effect on gradient calculations + gradO.linspace(-0.9, 0.15); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,0,2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test4) { + + NDArray input ('c', {2,4}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,4}, sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4}, {0.162923, -0.289673, 0.354174, -0.386151, -0.162923, 0.289673, -0.354174, 0.386151}, sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {1.442483, 0.950200, 0.569207, 0.314641}, sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {-1.2, -0.9, -0.6, -0.3}, sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + 
ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test5) { + +#if defined(HAVE_CUDNN) +return; +#endif + NDArray input ('c', {2,4,2,2}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2}, sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2}, {-0.737512, -0.659880, -0.582247, -0.504614, 0.561404, 0.502309, 0.443214, 0.384118, -1.168243, + -1.045270, -0.922297, -0.799324, 1.899026, 1.699128, 1.499231, 1.299333, 0.504614, 0.582247, 0.659880, 0.737512, -0.384118, + -0.443214, -0.502308, -0.561404, 0.799324, 0.922297, 1.045270, 1.168243, -1.299334, -1.499231, -1.699129, -1.899026}, sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {11.073181, 12.585667, 17.708657, 24.313186}, sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {4.2, 9. 
, 13.8, 18.6}, sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test6) { + +#if defined(HAVE_CUDNN) +return; +#endif + + NDArray input ('c', {2,2,2,4}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,4}, sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,4}, {-4.989124, 2.540357, -1.515022, 0.791769, -3.563660, 1.814540, -1.082159, 0.565549, -2.138196, 1.088724, -0.649295, + 0.339329, -0.712732, 0.362908, -0.216432, 0.113110, 0.712732, -0.362908, 0.216432, -0.113110, 2.138195, -1.088724, 0.649295, + -0.339330, 3.563660,-1.814540, 1.082159, -0.565549, 4.989125, -2.540356, 1.515022, -0.791770}, sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {20.364472, 17.856588, 16.949714, 15.903684}, sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12. 
, 13.2}, sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test7) { + +#if defined(HAVE_CUDNN) +return; +#endif + + NDArray input ('c', {2,2,2,2,4}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,2,4}, sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,2,4}, {-119.435059, 78.159744, -58.732986, 46.630123, -103.510391, 67.738441, -50.901920, 40.412773, -87.585716, 57.317142, + -43.070854, 34.195419, -71.661041, 46.895844, -35.239792, 27.978071, -55.736359, 36.474548, -27.408726, 21.760721, -39.811687, 26.053242, -19.577662, + 15.543370, -23.887009, 15.631950, -11.746595, 9.326023, -7.962326, 5.210644, -3.915531, 3.108671, 7.962341, -5.210655, 3.915535, -3.108677, 23.887032, + -15.631958, 11.746601, -9.326031, 39.811691, -26.053246, 19.577671, -15.543377, 55.736382, -36.474548, 27.408726, -21.760731, 71.661064, -46.895851, 35.239788, + -27.978077, 87.585732, -57.317154, 43.070866, -34.195431, 103.510384, -67.738464, 50.901920, -40.412777, 119.435097, -78.159744, 58.732998, -46.630131}, sd::DataType::FLOAT32); + 
NDArray expdLdG('c', {4}, {282.38734 , 244.542027, 224.140995, 207.548793}, sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {57.6, 60. , 62.4, 64.8}, sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,4}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + // dLdI->printBuffer(); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test8) { + +#if defined(HAVE_CUDNN) +return; +#endif + + NDArray input ('c', {2,4,2,2,2}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, {1.05, 1.15, 1.2, 1.3}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, {0.5, 0.7, 0.9, 1.1}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2,2}, sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2,2}, {-34.373802, -32.611046, -30.848286, -29.085529, -27.322769, -25.560009, -23.797251, -22.034491, 36.146996, 34.293301, + 32.439610, 30.585917, 28.732227, 26.878534, 25.024841, 23.171150, -42.876553, -40.677757, -38.478958, -36.280159, -34.081367, -31.882565, -29.683767, + -27.484968, 50.674446, 48.075760, 45.477066, 42.878380, 40.279686, 37.681000, 35.082310, 32.483616, 22.034489, 23.797249, 25.560009, 27.322765, 29.085526, + 30.848286, 32.611046, 34.373802, -23.171146, -25.024837, -26.878536, -28.732231, -30.585918, -32.439613, -34.293297, -36.146996, 27.484982, 29.683773, + 
31.882572, 34.081364, 36.280178, 38.478970, 40.677776, 42.876560, -32.483627, -35.082329, -37.681023, -40.279701, -42.878403, -45.477081, -48.075775, -50.674484}, sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {134.490365, 179.785003, 248.933114, 330.087248}, sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {32.4, 51.6, 70.8, 90.}, sd::DataType::FLOAT32); + + input.linspace(0.1, 0.1); + gradO.linspace(-0.9, 0.15); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + // dLdI->printBuffer(); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test9) { + + NDArray input ('c', {2,4,2,2}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,4,2,2}, sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,4,2,2}, {0.032378, 0.028967, 0.025558, 0.022147, -0.035056, -0.031364, -0.027669, -0.024006, 0.037742, 0.033766, 0.029791, 0.025818, + -0.040429, -0.036172, -0.031913, -0.027656, -0.022155, -0.025564, -0.028974, -0.032359, 0.023982, 0.027677, 0.031373, 0.035063, + -0.025822, -0.029794, -0.033770, -0.037747, 0.027653, 0.031913, 0.036168, 0.040426}, sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {9.685875, 9.685880, 9.685887, 9.685891}, sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {4.2, 9. 
, 13.8, 18.6}, sd::DataType::FLOAT32); + + input.linspace(1,0.01); + gradO.linspace(-0.9, 0.15); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0,2,3}; + int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.buffer(), input.shapeInfo(),input.specialBuffer(), input.specialShapeInfo(),nullptr,variance.buffer(), variance.shapeInfo(),variance.specialBuffer(), variance.specialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test10) { + + NDArray input ('c', {2,2,2,4}, sd::DataType::FLOAT32); + NDArray mean ('c', {4}, sd::DataType::FLOAT32); + NDArray variance('c', {4}, sd::DataType::FLOAT32); + NDArray gamma ('c', {4}, {-1.2, 1.3, -1.4, 1.5}, sd::DataType::FLOAT32); + NDArray beta ('c', {4}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,2,2,4}, 
sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,2,2,4}, {0.032634, -0.035423, 0.038110, -0.040864, 0.023302, -0.025294, 0.027213, -0.029205, 0.013996, -0.015192, 0.016343, + -0.017519, 0.004664, -0.005062, 0.005445, -0.005833, -0.004668, 0.005067, -0.005452, 0.005824, -0.013974, 0.015171, + -0.016325, 0.017508, -0.023309, 0.025301, -0.027221, 0.029197, -0.032639, 0.035428, -0.038118, 0.040878}, sd::DataType::FLOAT32); + NDArray expdLdG('c', {4}, {10.991656, 10.991631, 10.991643, 10.991632}, sd::DataType::FLOAT32); + NDArray expdLdB('c', {4}, {9.6, 10.8, 12., 13.2}, sd::DataType::FLOAT32); + + input.linspace(1,0.01); + gradO.linspace(-0.9, 0.15); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0,1,2}; + int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.buffer(), input.shapeInfo(),input.specialBuffer(), input.specialShapeInfo(),nullptr,variance.buffer(), variance.shapeInfo(),variance.specialBuffer(), variance.specialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, {&input}); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + 
ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, batchnorm_bp_test11) { + + NDArray input ('c', {2,3,4,5}, sd::DataType::FLOAT32); + NDArray mean ('c', {1,3,4,5}, sd::DataType::FLOAT32); + NDArray variance('c', {1,3,4,5}, sd::DataType::FLOAT32); + NDArray gamma ('c', {1,3,4,5}, sd::DataType::FLOAT32); + NDArray beta ('c', {1,3,4,5}, sd::DataType::FLOAT32); + NDArray gradO ('c', {2,3,4,5}, sd::DataType::FLOAT32); + + NDArray expdLdI('c', {2,3,4,5}, {0.004981, 0.004818, 0.004652, 0.004483, 0.004319, 0.004153, 0.003985, 0.003832, 0.003661, 0.003505, 0.003340, 0.003171, 0.003001, 0.002837, + 0.002670, 0.002505, 0.002337, 0.002167, 0.002003, 0.001835, 0.001666, 0.001499, 0.001327, 0.001162, 0.000996, 0.000830, 0.000664, 0.000498, + 0.000332, 0.000166, -0.0, -0.000166, -0.000333, -0.000500, -0.000668, -0.000835, -0.001003, -0.001168, -0.001337, -0.001502, -0.001670, + -0.001838, -0.002003, -0.002172, -0.002330, -0.002499, -0.002669, -0.002832, -0.003002, -0.003162, -0.003332, -0.003495, -0.003665, -0.003821, + -0.004001, -0.004163, -0.004324, -0.004516, -0.004678, -0.004851, -0.004981, -0.004818, -0.004652, -0.004483, -0.004319, -0.004151, -0.003985, + -0.003836, -0.003661, -0.003505, -0.003338, -0.003171, -0.003004, -0.002837, -0.002670, -0.002503, -0.002337, -0.002170, -0.002003, -0.001835, + -0.001664, -0.001499, -0.001328, -0.001162, -0.000996, -0.000829, -0.000664, -0.000498, -0.000332, -0.000166, 0.0, 0.000166, 0.000334, + 0.000500, 0.000668, 0.000834, 0.001003, 0.001170, 0.001337, 0.001502, 0.001669, 0.001838, 0.002005, 0.002172, 0.002330, 0.002496, 0.002669, + 0.002836, 0.003002, 0.003162, 0.003328, 0.003495, 0.003670, 0.003828, 0.003992, 0.004158, 0.004324, 0.004522, 0.004689, 0.004843}, sd::DataType::FLOAT32); + NDArray expdLdG('c', {1,3,4,5}, {8.999503, 8.999502, 8.999502, 
8.999503, 8.999502, 8.999503, 8.999503, 8.999499, 8.999501, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, 8.999498, + 8.999498, 8.999498, 8.999498, 8.999498, 8.999499, 8.999501, 8.999500, 8.999503, 8.999503, 8.999503, 8.999504, 8.999503, 8.999503, 8.999504, 8.999503, + 8.999504, 8.999504, 8.999499, 8.999500, 8.999497, 8.999498, 8.999496, 8.999496, 8.999496, 8.999498, 8.999498, 8.999496, 8.999496, 8.999496, 8.999501, + 8.999501, 8.999499, 8.999499, 8.999499, 8.999501, 8.999501, 8.999501, 8.999499, 8.999500, 8.999501, 8.999501, 8.999501, 8.999495, 8.999495, 8.999497}, sd::DataType::FLOAT32); + NDArray expdLdB('c', {1,3,4,5}, {7.2, 7.5, 7.8, 8.1, 8.4, 8.7, 9.0, 9.3, 9.6, 9.9, 10.2, 10.5, 10.8, 11.1, 11.4, 11.7, 12.0, 12.3, 12.6, 12.9, 13.2, 13.5, 13.8, 14.1, 14.4, 14.7, 15.0, + 15.3, 15.6, 15.9, 16.2, 16.5, 16.8, 17.1, 17.4, 17.7, 18.0, 18.3, 18.6, 18.9, 19.2, 19.5, 19.8, 20.1, 20.4, 20.7, 21.0, 21.3, 21.6, 21.9, 22.2, 22.5, + 22.8, 23.1, 23.4, 23.7, 24.0, 24.3, 24.6, 24.9}, sd::DataType::FLOAT32); + + input.linspace(1,0.01); + gradO.linspace(-0.9, 0.15); + gamma.linspace(-3, 0.1); + + // calculate mean and variance of input + PointersManager manager(input.getContext(), "DeclarableOpsTests13.batchnorm_bp_test9"); + std::vector dimensions = {0}; + int* dims = reinterpret_cast(manager.replicatePointer(dimensions.data(), dimensions.size() * sizeof(int))); + input.reduceAlongDimension(sd::reduce::Mean, mean, dimensions, true); + NDArray::prepareSpecialUse({&variance}, {&input}); + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(input.shapeInfo(), dimensions); + NativeOpExecutioner::execSummaryStats(input.getContext(), 0,input.buffer(), input.shapeInfo(),input.specialBuffer(), input.specialShapeInfo(),nullptr,variance.buffer(), variance.shapeInfo(),variance.specialBuffer(), variance.specialShapeInfo(), dims, dimensions.size(),packX.platformShapeInfo(), packX.platformOffsets(),false); + manager.synchronize(); + NDArray::registerSpecialUse({&variance}, 
{&input}); + + sd::ops::batchnorm_bp op; + + auto results = op.evaluate({&input, &mean, &variance, &gamma, &beta, &gradO}, {1e-5}, {1,1, 1,2,3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto dLdI = results.at(0); + auto dLdG = results.at(3); + auto dLdB = results.at(4); + + ASSERT_TRUE(expdLdI.isSameShapeStrict(*dLdI)); + ASSERT_TRUE(expdLdI.equalsTo(dLdI, 1e-4)); + + ASSERT_TRUE(expdLdG.isSameShapeStrict(*dLdG)); + ASSERT_TRUE(expdLdG.equalsTo(dLdG)); + + ASSERT_TRUE(expdLdB.isSameShapeStrict(*dLdB)); + ASSERT_TRUE(expdLdB.equalsTo(dLdB)); + +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests14.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests14.cpp new file mode 100644 index 000000000..47bd5af3b --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests14.cpp @@ -0,0 +1,2454 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// Created by raver on 8/4/2018. 
+// + +#include "testlayers.h" +#include +#include +#include +#include + + +using namespace sd; + + +class DeclarableOpsTests14 : public testing::Test { +public: + + DeclarableOpsTests14() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests14, Test_Validation_Edge_1) { + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto exp = NDArrayFactory::create('c', {2, 2}, Environment::getInstance().defaultFloatDataType()); + exp.assign(4.0f); + + sd::ops::fill op; + auto result = op.evaluate({&x}, {4.0f}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(exp, *z); + + +} + +TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_1) { + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); + + ASSERT_EQ(x, y); +} + +TEST_F(DeclarableOpsTests14, Test_Inf_Comparison_2) { +#ifdef FFAST_MATH + if (1 > 0) + return; +#endif + + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, std::numeric_limits::infinity(), 5}); + auto y = NDArrayFactory::create('c', {5}, {1, 2, 3, -std::numeric_limits::infinity(), 5}); + + ASSERT_NE(x, y); +} + +TEST_F(DeclarableOpsTests14, Multiply_test) { + + for(int k=2;k<10;k++){ + //nd4j_printf("k=%d\n", k); + NDArray x = NDArrayFactory::create('c', {k, 1}); + NDArray y = NDArrayFactory::create('c', {k}); + NDArray e = NDArrayFactory::create('c', {k, k}); + x.assign(1.0); + y.assign(1.0); + e.assign(1.0); + + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}); + auto f = result.at(0); + NDArray r = *f; + + ASSERT_EQ(e, r); + ASSERT_EQ(e, *f); + + + } +} + +TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_1) { + auto x = NDArrayFactory::create('c', {3}, {5, 3, 4}); + auto y = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {2}, {5, 4}); + + sd::ops::evaluate_reduction_shape op; + auto result = op.evaluate({&x, &y}, {}, {}, {false, 
false}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(e, *z); + + +} + +TEST_F(DeclarableOpsTests14, Test_EvalReductionShape_2) { + auto x = NDArrayFactory::create('c', {3}, {5, 3, 4}); + auto y = NDArrayFactory::create('c', {1}, {1}); + auto e = NDArrayFactory::create('c', {3}, {5, 1, 4}); + + sd::ops::evaluate_reduction_shape op; + auto result = op.evaluate({&x, &y}, {}, {}, {true, false}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(e, *z); + + +} + +TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_0) { + auto x = NDArrayFactory::create('c', {3, 4}, {-999.f, 0.2236f, 0.7973f, 0.0962f, 0.7231f, 0.3381f, -0.7301f, 0.9115f, -0.5094f, 0.9749f, -2.1340f, 0.6023f}); + auto z = NDArrayFactory::create('c', {4}); + auto e = NDArrayFactory::create('c', {4}, {-999.f, 0.2236f, -2.1340f, 0.0962f}); + + sd::ops::reduce_min op; + op.execute({&x}, {&z}, {}, {0}, {}); + + //z.printIndexedBuffer("Z"); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests14, Test_Reduce_Min_Small_1) { + auto x = NDArrayFactory::create('c', {3, 4}, {-999.f, 0.2236f, 0.7973f, 0.0962f, 0.7231f, 0.3381f, -0.7301f, 0.9115f, -0.5094f, 0.9749f, -2.1340f, 0.6023f}); + auto z = NDArrayFactory::create('c', {3}); + auto e = NDArrayFactory::create('c', {3}, {-999.f, -0.7301f, -2.1340f}); + + sd::ops::reduce_min op; + op.execute({&x}, {&z}, {}, {1}, {}); + + //z.printIndexedBuffer("Z"); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests14, Test_Diag_Zeros_1) { + auto x = NDArrayFactory::create('c', {2}, {1, 2}); + auto z = NDArrayFactory::create('c', {2, 2}, {-119, -119, -119, -119}); + auto exp = NDArrayFactory::create('c', {2, 2}, {1, 0, 0, 2}); + + sd::ops::diag op; + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp, z); +} + +TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_1) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {5, 10}); + 
auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0); + + + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(e, *result.at(0)); + + +} + +TEST_F(DeclarableOpsTests14, Test_scalar_broadcast_2) { + auto x = NDArrayFactory::create(1.0f); + auto y = NDArrayFactory::create('c', {5, 10}); + auto e = NDArrayFactory::create('c', {5, 10}); + y.assign(2.0f); + e.assign(-1.0f); + + + sd::ops::subtract op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(e, *result.at(0)); + + +} + +TEST_F(DeclarableOpsTests14, test_empty_fill_1) { + auto x = NDArrayFactory::empty(); + auto y = NDArrayFactory::create(1); + + sd::ops::fill op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(y, *z); + + +} + +TEST_F(DeclarableOpsTests14, test_lstmBlockCell_1) { + auto a = NDArrayFactory::create('c', {1, 5}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f}); + auto b = NDArrayFactory::create('c', {1, 3}); + auto c = NDArrayFactory::create('c', {1, 3}); + auto d = NDArrayFactory::create('c', {8, 12}, 
{-0.15320599,-0.120416045,0.33126968,0.13921785,-0.32313538,-0.43956736,0.4756174,0.4335605,-0.5450856,-0.3943429,-0.28687626,0.068032146,-0.2793799,0.17298919,-0.36553562,-0.097853184,-0.2544747,-0.39872527,-0.14556861,-0.31479517,0.2559092,0.47166896,-0.31330687,0.47313118,0.5134543,-0.4678212,-0.12853557,0.26142156,0.43472284,-0.42842552,-0.1895876,0.538689,0.508651,-0.020272732,0.112327516,0.2704304,-0.046546757,0.32570732,-0.15148133,-0.19145513,0.18631572,-0.024152994,0.41603214,-0.3421499,0.0106860995,-0.2966229,-0.36713937,0.25841123,0.0843398,0.49082482,0.10800403,0.1874243,-0.26379472,-0.22531849,0.24924624,0.23119557,0.49940765,-0.051413506,0.20315129,-0.41888732,0.44097036,0.40453392,0.013338983,0.23434466,0.23942488,0.47894,-0.19898453,0.09253675,-0.032358468,-0.15213022,-0.3441009,-0.15600958,-0.08235118,0.12165731,-0.4481289,-0.4842423,-0.45797008,-0.4606034,0.08163166,-0.2981107,0.50207126,0.44195646,0.13850057,0.072246075,-0.34388685,0.030900061,0.35821778,0.47900867,0.5094063,0.23683065,0.18020362,-0.1369732,0.015235603,0.2786904,0.07954317,0.12543976}); + auto e = NDArrayFactory::create('c', {3}); + auto f = NDArrayFactory::create('c', {3}); + auto g = NDArrayFactory::create('c', {3}); + auto h = NDArrayFactory::create('c', {12}); + + auto z0 = NDArrayFactory::create('c', {1, 3}); + auto z1 = NDArrayFactory::create('c', {1, 3}); + auto z2 = NDArrayFactory::create('c', {1, 3}); + auto z3 = NDArrayFactory::create('c', {1, 3}); + auto z4 = NDArrayFactory::create('c', {1, 3}); + auto z5 = NDArrayFactory::create('c', {1, 3}); + auto z6 = NDArrayFactory::create('c', {1, 3}); + + sd::ops::lstmBlockCell op; + auto result = op.execute({&a, &b, &c, &d, &e, &f, &g, &h}, {&z0, &z1, &z2, &z3, &z4, &z5, &z6}, {1.0, -1.0}, {0}, {}); + ASSERT_EQ(Status::OK(), result); +} + +TEST_F(DeclarableOpsTests14, test_empty_reduce_min_1) { + + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_min sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + 
ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + + ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); + +} + +TEST_F(DeclarableOpsTests14, test_empty_reduce_max_1) { + + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_max sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + + ASSERT_EQ(out->e(0), -DataTypeUtils::infOrMax()); + +} + +TEST_F(DeclarableOpsTests14, test_empty_reduce_sum_1) { +#ifdef FFAST_MATH + if (1 > 0) + return; +#endif + + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_sum sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + ASSERT_EQ(out->e(0), 0.f); + +} + +TEST_F(DeclarableOpsTests14, test_empty_reduce_mean_1) { +#ifdef FFAST_MATH + if (1 > 0) + return; +#endif + + auto e = NDArrayFactory::create('c', {1, 0}); + sd::ops::reduce_mean sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + // out->printShapeInfo("ReduceMean empty shape with keep dims"); + // out->printIndexedBuffer("ReduceMean scalar"); + ASSERT_TRUE(std::isnan(out->e(0))); + +} + +TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_1) { + auto matrix = NDArrayFactory::create('c', {1, 2, 0, 4}); + auto b = NDArrayFactory::create('c', {3}, {0, 0, 0}); + auto e = NDArrayFactory::create('c', {3}, {2,0,2}); + auto s = NDArrayFactory::create('c', {3}, {1,1,1}); + + auto exp = NDArrayFactory::create('c', {1,0,0,4}); + + matrix.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + +} + +TEST_F(DeclarableOpsTests14, Test_StridedSliceZeros_2) { + auto matrix = NDArrayFactory::create('c', {1, 2, 0, 4}); + auto b = NDArrayFactory::create('c', {3}, {0, 0, 0}); + auto e 
= NDArrayFactory::create('c', {3}, {2,0,2}); + auto s = NDArrayFactory::create('c', {3}, {1,1,1}); + + auto exp = NDArrayFactory::create('c', {0,0,4}); + + matrix.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + +} + +TEST_F(DeclarableOpsTests14, test_empty_argmax_1) { + auto x = NDArrayFactory::create('c', {1, 0}); + auto y = NDArrayFactory::create(0); + auto e = NDArrayFactory::create('c', {0}); + + sd::ops::argmax op; + //sd::ops::reduce_max op; + + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + + +} + +TEST_F(DeclarableOpsTests14, test_empty_argmax_2) { + auto x = NDArrayFactory::create('c', {1, 0}); + auto y = NDArrayFactory::create(1); + + sd::ops::argmax op; + try { + auto result = op.execute({&x, &y}, {&y}, {}, {}, {}); + ASSERT_TRUE(false); + } catch (std::exception &e) { + // + } +} + +TEST_F(DeclarableOpsTests14, test_empty_tanh_5) { + auto x = NDArrayFactory::create('c', {32, 0}); + + sd::ops::tanh op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x.isSameShape(z)); + ASSERT_EQ(x, *z); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, repeat_1) { + + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); + + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {2, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, repeat_2) { + + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', 
{2, 6}, {1, 1, 2, 2, 3, 3,4, 4, 5, 5, 6, 6}); + + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, repeat_3) { + + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {2, 6}, {1, 2, 2, 3, 3, 3,4, 5, 5, 6, 6, 6}); + + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {1,2,3, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, repeat_4) { + + NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + NDArray e('c', {7, 3}, {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6}); + + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {3,4, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, repeat_5) { + + NDArray x('c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); + NDArray e('c', {2, 4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22, 23, 24}); + + sd::ops::repeat op; + auto result = op.evaluate({&x}, {}, {1,2,1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + + +} +///////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest) { + + auto y = NDArray('c', { 3 }, sd::DataType::FLOAT32); + auto x = NDArray('c', { 5, 2, 1 }, 
sd::DataType::FLOAT32); + + auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, sd::DataType::FLOAT32); + + y.assign(1.0); + x.linspace(1.0); + + sd::ops::add op; + auto result = op.evaluate({ &x, &y }); + ASSERT_EQ(Status::OK(), result.status()); + + auto res = *result.at(0); + + ASSERT_EQ(e, res); + + +} +///////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest2) { + + auto y = NDArray('c', { 1, 3 }, sd::DataType::FLOAT32); + auto x = NDArray('c', { 5, 2, 1 }, sd::DataType::FLOAT32); + + auto e = NDArray('c', { 5, 2, 3 }, { 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5., 5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9., 10., 10., 10., 11., 11., 11. }, sd::DataType::FLOAT32); + + y.assign(1.0); + x.linspace(1.0); + + sd::ops::add op; + auto result = op.evaluate({ &x, &y }); + ASSERT_EQ(Status::OK(), result.status()); + + auto res = *result.at(0); + + ASSERT_EQ(e, res); + + +} + +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest3) { + + auto x = NDArray('c', { 3, 5, 1 }, sd::DataType::FLOAT32); + auto y = NDArray('c', { 3, 1, 4 }, sd::DataType::FLOAT32); + auto z = NDArray('c', { 3, 5, 4 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray('c', { 3, 5, 4 }, { 10., 11., 12., 13., 20., 22., 24., 26., 30., 33., 36., 39., 40., 44., 48., 52., 50., 55., 60., 65., 84., 90., 96., 102., 98., 105., 112., 119., 112., 120., 128., 136., 126., 135., 144., 153., 140., 150., 160., 170., 198., 209., 220., 231., 216., 228., 240., 252., 234., 247., 260., 273., 252., 266., 280., 294., 270., 285., 300., 315. 
}, sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest4) { + + auto x = NDArray('c', { 2, 3, 5, 1 }, sd::DataType::FLOAT32); + auto y = NDArray('c', { 2, 3, 1, 4 }, sd::DataType::FLOAT32); + auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray('c', { 2, 3, 5, 4 }, { 10., 11., 12., 13.,20., 22., 24., 26.,30., 33., 36., 39.,40., 44., 48., 52.,50., 55., 60., 65.,84., 90., 96., 102.,98., 105., 112., 119.,112., 120., 128., 136.,126., 135., 144., 153.,140., 150., 160., 170.,198., 209., 220., 231.,216., 228., 240., 252.,234., 247., 260., 273.,252., 266., 280., 294.,270., 285., 300., 315.,352., 368., 384., 400.,374., 391., 408., 425.,396., 414., 432., 450.,418., 437., 456., 475.,440., 460., 480., 500.,546., 567., 588., 609.,572., 594., 616., 638.,598., 621., 644., 667.,624., 648., 672., 696.,650., 675., 700., 725.,780., 806., 832., 858.,810., 837., 864., 891.,840., 868., 896., 924.,870., 899., 928., 957.,900., 930., 960., 990. 
}, sd::DataType::FLOAT32); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest5) { + + auto x = NDArray('c', { 3, 5, 1 }, sd::DataType::FLOAT32); + auto y = NDArray('c', { 3, 1, 4 }, sd::DataType::FLOAT32); + auto z = NDArray('c', { 3, 5, 4 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray('c', { 3, 5, 4 }, { 0.1, 0.090909, 0.083333, 0.076923,0.2, 0.181818, 0.166667, 0.153846,0.3, 0.272727, 0.250000, 0.230769,0.4, 0.363636, 0.333333, 0.307692,0.5, 0.454545, 0.416667, 0.384615, 0.428571, 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, 0.571429, 0.533333, 0.500000, 0.470588, 0.642857, 0.600000, 0.562500, 0.529412, 0.714286, 0.666667, 0.625000, 0.588235, 0.611111, 0.578947, 0.550000, 0.523810, 0.666667, 0.631579, 0.600000, 0.571429, 0.722222, 0.684211, 0.650000, 0.619048, 0.777778, 0.736842, 0.700000, 0.666667, 0.833333, 0.789474, 0.750000, 0.714286 }, sd::DataType::FLOAT32); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest6) { + + auto x = NDArray('c', { 2, 3, 5, 1 }, sd::DataType::FLOAT32); + auto y = NDArray('c', { 2, 3, 1, 4 }, sd::DataType::FLOAT32); + auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray('c', { 2, 3, 5, 4 }, { 0.1, 0.090909, 0.083333, 0.076923,0.2, 0.181818, 0.166667, 0.153846,0.3, 0.272727, 0.250000, 0.230769,0.4, 0.363636, 0.333333, 0.307692,0.5, 0.454545, 0.416667, 0.384615, 0.428571, 0.400000, 0.375000, 0.352941, 0.500000, 0.466667, 0.437500, 0.411765, 0.571429, 0.533333, 0.500000, 0.470588, 
0.642857, 0.600000, 0.562500, 0.529412, 0.714286, 0.666667, 0.625000, 0.588235,0.611111, 0.578947, 0.550000, 0.523810,0.666667, 0.631579, 0.600000, 0.571429,0.722222, 0.684211, 0.650000, 0.619048,0.777778, 0.736842, 0.700000, 0.666667,0.833333, 0.789474, 0.750000, 0.714286, 0.727273, 0.695652, 0.666667, 0.64, 0.772727, 0.739130, 0.708333, 0.68, 0.818182, 0.782609, 0.750000, 0.72, 0.863636, 0.826087, 0.791667, 0.76, 0.909091, 0.869565, 0.833333, 0.80, 0.807692, 0.777778, 0.750000, 0.724138, 0.846154, 0.814815, 0.785714, 0.758621, 0.884615, 0.851852, 0.821429, 0.793103, 0.923077, 0.888889, 0.857143, 0.827586, 0.961538, 0.925926, 0.892857, 0.862069, 0.866667, 0.838710, 0.812500, 0.787879, 0.900000, 0.870968, 0.843750, 0.818182, 0.933333, 0.903226, 0.875000, 0.848485, 0.966667, 0.935484, 0.906250, 0.878788, 1.000000, 0.967742, 0.937500, 0.909091 }, sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + ASSERT_EQ(e, z); +} + +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest7) { + + auto x = NDArray('c', { 3, 5, 1 }, sd::DataType::FLOAT32); + auto y = NDArray('c', { 3, 1, 4 }, sd::DataType::FLOAT32); + auto z = NDArray('c', { 3, 5, 4 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray('c', { 3, 5, 4 }, { -9., -10., -11., -12.,-8., -9., -10., -11., -7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-8., -9., -10., -11.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-7., -8.000000, -9.000000, -10.00,-6.000000, -7.000000, -8.000000, -9.000,-5.000000, -6.000000, -7.000000, -8.000,-4.000000, -5.000000, -6.000000, -7.000,-3.000000, -4.000000, -5.000000, -6.000 }, sd::DataType::FLOAT32); + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); + ASSERT_EQ(e, z); +} 
+/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_SpecialCaseTest8) { + + auto x = NDArray('c', { 2, 3, 5, 1 }, sd::DataType::FLOAT32); + auto y = NDArray('c', { 2, 3, 1, 4 }, sd::DataType::FLOAT32); + auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray('c', { 2, 3, 5, 4 }, { -9.0, -10., -11., -12.,-8., -9., -10., -11.0,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-8., -9., -10., -11.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-7., -8., -9., -10.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-6., -7., -8., -9.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-5., -6., -7., -8.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-1., -2., -3., -4.,-4., -5., -6., -7.,-3., -4., -5., -6.,-2., -3., -4., -5.,-1., -2., -3., -4., 0., -1., -2., -3. }, sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Subtract(), y, z); + ASSERT_EQ(e, z); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test1) { + + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123., 40., 92., 144., 45., 105., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test2) { + + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('f', {4, 3}); + auto exp = 
NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test3) { + + auto x = NDArrayFactory::create('f', {3, 4}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test4) { + + auto x = NDArrayFactory::create ('f', {3, 4}); + auto y = NDArrayFactory::create('f', {4, 3}); + auto exp = NDArrayFactory::create('f', {3, 3}, {35., 79., 123.,40., 92., 144.,45.,105., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test5) { + + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('f', {3, 3}, {83., 94., 105., 94., 107., 120., 105., 120., 135.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1}); + auto z = results.at(0); + + 
ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test6) { + + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('f', {3, 4}); + auto exp = NDArrayFactory::create('f', {3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test7) { + + auto x = NDArrayFactory::create('c', {5, 3,4}); + auto y = NDArrayFactory::create('f', {5, 3,4}); + auto exp = NDArrayFactory::create('f',{5, 3,3}, {3. , 84.6, 281.4, 593.4, 1020.6, 7. , 107.8, 323.8, 655. , 1101.4,11. , 131. , 366.2, 716.6, 1182.2, + 7. , 107.8, 323.8, 655. , 1101.4,17.4, 137.4, 372.6, 723. , 1188.6,27.8, 167. , 421.4, 791. , 1275.8, + 11. , 131. , 366.2, 716.6, 1182.2,27.8, 167. , 421.4, 791. , 1275.8,44.6, 203. , 476.6, 865.4, 1369.4,}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test8) { + + auto x = NDArrayFactory::create('c', {2,5, 3,4}); + auto y = NDArrayFactory::create('f', {2,5, 3,4}); + auto exp = NDArrayFactory::create('f',{2,5, 3,3}, {3. , 1563. , 84.6, 2220.6, 281.4, 2993.4, 593.4, 3881.4,1020.6, 4884.6, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4, + 11. , 1763. 
, 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 7. , 1663. , 107.8, 2339.8, 323.8, 3131.8, 655. , 4039. ,1101.4, 5061.4, + 17.4, 1769.4, 137.4, 2465.4, 372.6, 3276.6, 723. , 4203. ,1188.6, 5244.6, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8, + 11. , 1763. , 131. , 2459. , 366.2, 3270.2, 716.6, 4196.6,1182.2, 5238.2, 27.8, 1875.8, 167. , 2591. , 421.4, 3421.4, 791. , 4367. ,1275.8, 5427.8, + 44.6, 1988.6, 203. , 2723. , 476.6, 3572.6, 865.4, 4537.4,1369.4, 5617.4}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test9) { + + auto x = NDArrayFactory::create('c', {2,5, 4,3}); + auto y = NDArrayFactory::create('f', {2,5, 3,4}); + auto exp = NDArrayFactory::create('f',{2,5, 3,3}, {7. , 1639. , 103. , 2311. , 314.2, 3098.2, 640.6, 4000.6,1082.2, 5018.2, 8. , 1664. , 108.8, 2340.8, 324.8, 3132.8, 656. , 4040. ,1102.4, 5062.4, + 9. , 1689. , 114.6, 2370.6, 335.4, 3167.4, 671.4, 4079.4,1122.6, 5106.6, 15.8, 1743.8, 131. , 2435. , 361.4, 3241.4, 707. , 4163. ,1167.8, 5199.8, + 18.4, 1770.4, 138.4, 2466.4, 373.6, 3277.6, 724. , 4204. ,1189.6, 5245.6, 21. , 1797. , 145.8, 2497.8, 385.8, 3313.8, 741. , 4245. ,1211.4, 5291.4, + 24.6, 1848.6, 159. , 2559. , 408.6, 3384.6, 773.4, 4325.4,1253.4, 5381.4, 28.8, 1876.8, 168. , 2592. , 422.4, 3422.4, 792. , 4368. ,1276.8, 5428.8, + 33. , 1905. , 177. , 2625. 
, 436.2, 3460.2, 810.6, 4410.6,1300.2, 5476.2}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, matmul_test10) { + + auto x = NDArrayFactory::create_('c', {3, 5}); + x->linspace(1); + + auto y = NDArrayFactory::create_('c', {5, 3}); + y->linspace(1); + + float _expB[]{135.0f, 310.0f, 485.0f, 150.0f, 350.0f, 550.0f, 165.0f, 390.0f, 615.0f}; + Nd4jLong _expS[] {2, 3, 3, 1, 3, 0, 1, 102}; // expected shape + ArrayOptions::setDataType(_expS, sd::DataType::FLOAT32); + NDArray exp(_expB, _expS); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(1, new Variable()); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1, -2}); + + sd::ops::matmul op; + + Nd4jStatus status = op.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(variableSpace->hasVariable(1)); + + auto result = variableSpace->getVariable(1)->getNDArray(); + + ASSERT_TRUE(result->equalsTo(&exp)); + + delete block; + delete variableSpace; +} + +TEST_F(DeclarableOpsTests14, matmul_test11) { + auto A = NDArrayFactory::create('c', {3, 3}); + auto B = NDArrayFactory::create('c', {3, 1}); + auto exp = NDArrayFactory::create('c', {3, 1}, {14.00f, 32.00f, 50.00f}); + + A.linspace(1); + B.linspace(1); + + sd::ops::matmul op; + + auto result = op.evaluate({&A, &B}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, matmul_test12) { + auto x= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); + auto y= NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8 , 9, 10, 11, 12}); + auto exp= 
NDArrayFactory::create('f', {4, 4}, {38.0, 44.0, 50.0, 56.0, 83.0, 98.0, 113.0, 128.0, 128.0, 152.0, 176.0, 200.0, 173.0, 206.0, 239.0, 272.0}); + + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + + +} + + +TEST_F(DeclarableOpsTests14, matmul_test13) { + auto x= NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); + + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, matmul_test14) { + auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); + + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, matmul_test15) { + auto x= NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {3, 4}, {1.0, 2.0, 3.0, 2.0, 4.0, 6.0, 3.0, 6.0, 9.0, 4.0, 8.0, 12.0}); + + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + 
+TEST_F(DeclarableOpsTests14, matmul_test16) { + auto x= NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y= NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp= NDArrayFactory::create('f', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); + + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, matmul_test17) { + auto x = NDArrayFactory::create('c', {1, 2}, {2.0f, 2.0f}); + auto y = NDArrayFactory::create('c', {2, 1}, {2.0f, 2.0f}); + auto exp = NDArrayFactory::create('c', {1, 1}, {8.0f}); + + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(exp, *result.at(0)); + + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test18) { + + auto x = NDArrayFactory::create('c', {1, 4, 3}); + auto y = NDArrayFactory::create('f', {1, 3, 4}); + auto exp = NDArrayFactory::create('f', {1, 3, 3}, {35., 40., 45., 79., 92., 105., 123., 144., 165.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test19) { + + auto x = NDArrayFactory::create('c', {4, 1}); + auto y = NDArrayFactory::create('f', {1, 4}); + auto exp = NDArrayFactory::create('f', {1, 1}, {15}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), results.status()); + + auto z = results.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + 
ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test20) { + + auto x = NDArrayFactory::create('c', {1, 4, 1}); + auto y = NDArrayFactory::create('f', {1, 1, 4}); + auto exp = NDArrayFactory::create('f', {1, 1, 1}, {15}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + + ASSERT_EQ(Status::OK(), results.status()); + auto z = results.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test21) { + + auto x = NDArrayFactory::create('c', {2, 3}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('f', {5, 2}, {23. , 26. , 29. , 32. , 35., 50. , 57.5, 65. , 72.5, 80.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 0, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test22) { + + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. 
, 70.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test23) { + + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('f', {5, 2}, {37. , 41.5, 46. , 50.5, 55., 46. , 52. , 58. , 64. , 70.}); + + x.linspace(1.); + y.linspace(0.5, 0.5); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test24) { + + auto x = NDArrayFactory::create('c', {2,2, 3,5}); + auto y = NDArrayFactory::create('c', {2,2, 4,3}); + auto exp = NDArrayFactory::create('f',{2,2, 4,5}, {4.6, 281.8, 89.2, 582.4, 10. , 314.2,108.1, 628.3, 15.4, 346.6,127. , 674.2, 20.8, 379. ,145.9, 720.1, 5.2, 289.6, 93.4, 593.8, + 11.5, 322.9,113.2, 640.6, 17.8, 356.2,133. , 687.4, 24.1, 389.5,152.8, 734.2, 5.8, 297.4, 97.6, 605.2, 13. , 331.6,118.3, 652.9, + 20.2, 365.8,139. , 700.6, 27.4, 400. ,159.7, 748.3, 6.4, 305.2,101.8, 616.6, 14.5, 340.3,123.4, 665.2, 22.6, 375.4,145. , 713.8, + 30.7, 410.5,166.6, 762.4, 7. , 313. ,106. , 628. , 16. , 349. ,128.5, 677.5, 25. , 385. ,151. , 727. , 34. , 421. 
,173.5, 776.5}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test25) { + + auto x = NDArrayFactory::create('f', {4, 3}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create('f',{3}, {7., 8., 9.}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 0}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test26) { + + auto x = NDArrayFactory::create('f', {3}); + auto y = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('f',{4}, {1.4, 3.2, 5., 6.8}); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {0, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test27) { + + auto x = NDArrayFactory::create('f', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f',{1, 1}, {0.2}); + + x.linspace(2.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests14, matmul_test28) { + + auto x = NDArrayFactory::create('f', {1, 1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f',{1, 1}, {0.2}); + + x.linspace(2.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1,1,1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test29) { + + auto x = NDArrayFactory::create('f', {1}); + auto y = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('f',{1}, {0.2}); + + x.linspace(2.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test30) { + + auto x = NDArrayFactory::create('f', {1,1}); + auto y = NDArrayFactory::create('c', {1}); + auto exp = NDArrayFactory::create('f',{1}, {0.2}); + + x.linspace(2.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test31) { + + auto x = NDArrayFactory::create('f', {4}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create(3.); + + x.linspace(1.); + y.linspace(0.1, 0.1); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + 
ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test32) { + + auto x = NDArrayFactory::create('f', {1}, {2.}); + auto y = NDArrayFactory::create('c', {1}, {3.}); + auto exp = NDArrayFactory::create(6.); + + sd::ops::matmul op; + auto results = op.evaluate({&x, &y}, {}, {1, 1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +///////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test33) { + auto x = NDArrayFactory::create('c', {4, 3}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto exp = NDArrayFactory::create('c',{ 3, 1}, {70, 80, 90}); + + x.linspace(1); + y.linspace(1); + + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test34) { + auto a = NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {30, 70, 110}); + + sd::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +///////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test35) { + auto a = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = 
NDArrayFactory::create('c', {3}, {70, 80, 90}); + + sd::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +//////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test36) { + auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {1, 3}, {70, 80, 90}); + + sd::ops::matmul op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, matmul_test37) { + + NDArray a('c', {32, 12, 128, 64}, sd::DataType::FLOAT32); + NDArray b('c', {32, 12, 128, 64}, sd::DataType::FLOAT32); + NDArray c('c', {32,12,128,128}, sd::DataType::FLOAT32); + NDArray cExp('c', {32,12,128,128}, sd::DataType::FLOAT32); + + a = 1; + b = 1; + cExp = 64; //Each entry in output c is sum of 64 (1.0 x 1.0) multiplications + + sd::ops::matmul op; + auto status = op.execute({&a, &b}, {&c}, {}, {0,1}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(cExp.isSameShape(c)); + ASSERT_TRUE(cExp.equalsTo(c)); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_3D_1) { + + // x[4, 12, 128] * y[4, 128] = z[4, 12, 128] + + auto x = NDArray('c', { 2, 3, 5 }, sd::DataType::FLOAT32); + auto y = NDArray('c', { 2, 5 }, sd::DataType::FLOAT32); + auto z = NDArray('c', { 2, 3, 5 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray('c', { 2, 3, 5 }, { 10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 60.000000, 77.000000, 96.000000, 
117.000000, 140.000000, 110.000000, 132.000000, 156.000000, 182.000000, 210.000000, 240.000000, 272.000000, 306.000000, 342.000000, 380.000000, 315.000000, 352.000000, 391.000000, 432.000000, 475.000000, 390.000000, 432.000000, 476.000000, 522.000000, 570.000000 }, sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Multiply, { 0,2 }, y, z); + //z.printBuffer(); + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_3D_2) { + + auto x = NDArray('f', { 2, 3, 5 }, sd::DataType::FLOAT32); + auto y = NDArray('f', { 2, 5 }, sd::DataType::FLOAT32); + auto z = NDArray('f', { 2, 3, 5 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray('c', { 2, 3, 5 }, { 0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.600000, 0.636364, 0.666667, 0.692308, 0.714286, 1.100000, 1.090909, 1.083333, 1.076923, 1.071429, 1.066667, 1.062500, 1.058824, 1.055556, 1.052632, 1.400000, 1.375000, 1.352941, 1.333333, 1.315789, 1.733333, 1.687500, 1.647059, 1.611111, 1.578947 }, sd::DataType::FLOAT32); + + auto e = NDArray('f', { 2, 3, 5 }, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, { 0,2 }, y, z); + + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_4D_1) { + + auto x = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + auto y = NDArray('c', { 2, 5, 4 }, sd::DataType::FLOAT32); + auto z = NDArray('c', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray('c', { 2, 3, 5, 4 }, { 10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 90.000000, 112.000000, 136.000000, 162.000000, 190.000000, 220.000000, 252.000000, 286.000000, 322.000000, 360.000000, 400.000000, 442.000000, 486.000000, 532.000000, 
580.000000, 210.000000, 242.000000, 276.000000, 312.000000, 350.000000, 390.000000, 432.000000, 476.000000, 522.000000, 570.000000, 620.000000, 672.000000, 726.000000, 782.000000, 840.000000, 900.000000, 962.000000, 1026.000000, 1092.000000, 1160.000000, 410.000000, 462.000000, 516.000000, 572.000000, 630.000000, 690.000000, 752.000000, 816.000000, 882.000000, 950.000000, 1020.000000, 1092.000000, 1166.000000, 1242.000000, 1320.000000, 1400.000000, 1482.000000, 1566.000000, 1652.000000, 1740.000000, 1830.000000, 1922.000000, 2016.000000, 2112.000000, 2210.000000, 2310.000000, 2412.000000, 2516.000000, 2622.000000, 2730.000000, 2840.000000, 2952.000000, 3066.000000, 3182.000000, 3300.000000, 3420.000000, 3542.000000, 3666.000000, 3792.000000, 3920.000000, 2430.000000, 2542.000000, 2656.000000, 2772.000000, 2890.000000, 3010.000000, 3132.000000, 3256.000000, 3382.000000, 3510.000000, 3640.000000, 3772.000000, 3906.000000, 4042.000000, 4180.000000, 4320.000000, 4462.000000, 4606.000000, 4752.000000, 4900.000000, 3030.000000, 3162.000000, 3296.000000, 3432.000000, 3570.000000, 3710.000000, 3852.000000, 3996.000000, 4142.000000, 4290.000000, 4440.000000, 4592.000000, 4746.000000, 4902.000000, 5060.000000, 5220.000000, 5382.000000, 5546.000000, 5712.000000, 5880.000000 }, sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Multiply, { 0,2,3 }, y, z); + + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_4D_2) { + + auto x = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + auto y = NDArray('f', { 2, 5, 4 }, sd::DataType::FLOAT32); + auto z = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray('c', { 2, 3, 5, 4 }, { 0.100000,0.181818,0.250000,0.307692,0.357143,0.400000,0.437500,0.470588,0.500000,0.526316,0.550000,0.571429, 0.590909,0.608696,0.625000,0.640000, 
0.653846,0.666667,0.678571,0.689655, 2.100000,2.000000,1.916667, 1.846154, 1.785714, 1.733333,1.687500, 1.647059,1.611111, 1.578947,1.550000, 1.523810,1.500000, 1.478261,1.458333, 1.440000,1.423077, 1.407407,1.392857, 1.379310,4.100000, 3.818182,3.583333, 3.384615, 3.214286, 3.066667,2.937500, 2.823529,2.722222, 2.631579,2.550000, 2.476191,2.409091, 2.347826,2.291667, 2.240000,2.192308, 2.148148,2.107143, 2.068965,2.033333, 2.000000,1.968750, 1.939394,1.911765, 1.885714,1.861111, 1.837838,1.815789, 1.794872,1.775000, 1.756098,1.738095, 1.720930,1.704545, 1.688889,1.673913, 1.659575,1.645833,1.632653,2.700000,2.645161,2.593750,2.545455,2.500000,2.457143,2.416667,2.378378,2.342105,2.307692,2.275000,2.243902,2.214286,2.186047,2.159091,2.133333,2.108696,2.085106,2.062500,2.040816,3.366667,3.290323,3.218750,3.151515,3.088235,3.028571,2.972222,2.918919,2.868421,2.820513,2.775000,2.731707,2.690476,2.651163,2.613636,2.577778,2.543478,2.510638,2.479167,2.448980 }, sd::DataType::FLOAT32); + + auto e = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, { 0,2,3 }, y, z); + + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_4D_3) { + + auto x = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + auto y = NDArray('f', { 2, 5 }, sd::DataType::FLOAT32); + auto z = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray('c', { 2, 3, 5, 4 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.454545, 0.545455, 0.636364, 0.727273, 0.750000, 0.833333, 0.916667, 1.000000, 1.000000, 1.076923, 1.153846, 1.230769, 1.214286, 1.285714, 1.357143, 1.428571, 2.100000, 2.200000, 2.300000, 2.400000, 2.272727, 2.363636, 2.454545, 2.545455, 2.416667, 2.500000, 2.583333, 2.666667, 2.538461, 2.615385, 2.692308, 2.769231, 2.642857, 2.714286, 
2.785714, 2.857143, 4.100000, 4.200000, 4.300000, 4.400000, 4.090909, 4.181818, 4.272727, 4.363636, 4.083333, 4.166667, 4.250000, 4.333333, 4.076923, 4.153846, 4.230769, 4.307693, 4.071429, 4.142857, 4.214286, 4.285714, 4.066667, 4.133333, 4.200000, 4.266667, 4.062500, 4.125000, 4.187500, 4.250000, 4.058824, 4.117647, 4.176471, 4.235294, 4.055555, 4.111111, 4.166667, 4.222222, 4.052631, 4.105263, 4.157895, 4.210526, 5.400000, 5.466667, 5.533333, 5.600000, 5.312500, 5.375000, 5.437500, 5.500000, 5.235294, 5.294117, 5.352941, 5.411765, 5.166667, 5.222222, 5.277778, 5.333333, 5.105263, 5.157895, 5.210526, 5.263158, 6.733333, 6.800000, 6.866667, 6.933333, 6.562500, 6.625000, 6.687500, 6.750000, 6.411765, 6.470588, 6.529412, 6.588235, 6.277778, 6.333333, 6.388889, 6.444445, 6.157895, 6.210526, 6.263158, 6.315790 }, sd::DataType::FLOAT32); + + auto e = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, { 0,2 }, y, z); + + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_4D_4) { + + // x[4, 12, 128, 128] * y[4, 1, 128, 1] = z[4, 12, 128, 128] + + auto x = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + auto y = NDArray('f', { 2, 1, 5, 1 }, sd::DataType::FLOAT32); + auto z = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray('c', { 2, 3, 5, 4 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.454545, 0.545455, 0.636364, 0.727273, 0.750000, 0.833333, 0.916667, 1.000000, 1.000000, 1.076923, 1.153846, 1.230769, 1.214286, 1.285714, 1.357143, 1.428571, 2.100000, 2.200000, 2.300000, 2.400000, 2.272727, 2.363636, 2.454545, 2.545455, 2.416667, 2.500000, 2.583333, 2.666667, 2.538461, 2.615385, 2.692308, 2.769231, 2.642857, 2.714286, 2.785714, 2.857143, 4.100000, 4.200000, 4.300000, 4.400000, 4.090909, 4.181818, 
4.272727, 4.363636, 4.083333, 4.166667, 4.250000, 4.333333, 4.076923, 4.153846, 4.230769, 4.307693, 4.071429, 4.142857, 4.214286, 4.285714, 4.066667, 4.133333, 4.200000, 4.266667, 4.062500, 4.125000, 4.187500, 4.250000, 4.058824, 4.117647, 4.176471, 4.235294, 4.055555, 4.111111, 4.166667, 4.222222, 4.052631, 4.105263, 4.157895, 4.210526, 5.400000, 5.466667, 5.533333, 5.600000, 5.312500, 5.375000, 5.437500, 5.500000, 5.235294, 5.294117, 5.352941, 5.411765, 5.166667, 5.222222, 5.277778, 5.333333, 5.105263, 5.157895, 5.210526, 5.263158, 6.733333, 6.800000, 6.866667, 6.933333, 6.562500, 6.625000, 6.687500, 6.750000, 6.411765, 6.470588, 6.529412, 6.588235, 6.277778, 6.333333, 6.388889, 6.444445, 6.157895, 6.210526, 6.263158, 6.315790 }, sd::DataType::FLOAT32); + + auto e = NDArray('f', { 2, 3, 5, 4 }, sd::DataType::FLOAT32); + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_5D_1) { + // x[4, 12, 128, 128, 128] * y[4, 1, 128, 128, 128] = z[4, 12, 128, 128, 128] + auto x = NDArray('c', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + auto y = NDArray('c', { 2, 1, 5, 4, 3 }, sd::DataType::FLOAT32); + auto z = NDArray('c', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto e = NDArray('c', { 2, 3, 5, 4, 3 }, { 10.000000, 22.000000, 36.000000, 52.000000, 70.000000, 90.000000, 112.000000, 136.000000, 162.000000, 190.000000, 220.000000, 252.000000, 286.000000, 322.000000, 360.000000, 400.000000, 442.000000, 486.000000, 532.000000, 580.000000, 630.000000, 682.000000, 736.000000, 792.000000, 850.000000, 910.000000, 972.000000, 1036.000000, 1102.000000, 1170.000000, 1240.000000, 1312.000000, 1386.000000, 1462.000000, 1540.000000, 1620.000000, 1702.000000, 1786.000000, 1872.000000, 1960.000000, 2050.000000, 2142.000000, 
2236.000000, 2332.000000, 2430.000000, 2530.000000, 2632.000000, 2736.000000, 2842.000000, 2950.000000, 3060.000000, 3172.000000, 3286.000000, 3402.000000, 3520.000000, 3640.000000, 3762.000000, 3886.000000, 4012.000000, 4140.000000, 610.000000, 682.000000, 756.000000, 832.000000, 910.000000, 990.000000, 1072.000000, 1156.000000, 1242.000000, 1330.000000, 1420.000000, 1512.000000, 1606.000000, 1702.000000, 1800.000000, 1900.000000, 2002.000000, 2106.000000, 2212.000000, 2320.000000, 2430.000000, 2542.000000, 2656.000000, 2772.000000, 2890.000000, 3010.000000, 3132.000000, 3256.000000, 3382.000000, 3510.000000, 3640.000000, 3772.000000, 3906.000000, 4042.000000, 4180.000000, 4320.000000, 4462.000000, 4606.000000, 4752.000000, 4900.000000, 5050.000000, 5202.000000, 5356.000000, 5512.000000, 5670.000000, 5830.000000, 5992.000000, 6156.000000, 6322.000000, 6490.000000, 6660.000000, 6832.000000, 7006.000000, 7182.000000, 7360.000000, 7540.000000, 7722.000000, 7906.000000, 8092.000000, 8280.000000, 1210.000000, 1342.000000, 1476.000000, 1612.000000, 1750.000000, 1890.000000, 2032.000000, 2176.000000, 2322.000000, 2470.000000, 2620.000000, 2772.000000, 2926.000000, 3082.000000, 3240.000000, 3400.000000, 3562.000000, 3726.000000, 3892.000000, 4060.000000, 4230.000000, 4402.000000, 4576.000000, 4752.000000, 4930.000000, 5110.000000, 5292.000000, 5476.000000, 5662.000000, 5850.000000, 6040.000000, 6232.000000, 6426.000000, 6622.000000, 6820.000000, 7020.000000, 7222.000000, 7426.000000, 7632.000000, 7840.000000, 8050.000000, 8262.000000, 8476.000000, 8692.000000, 8910.000000, 9130.000000, 9352.000000, 9576.000000, 9802.000000, 10030.000000, 10260.000000, 10492.000000, 10726.000000, 10962.000000, 11200.000000, 11440.000000, 11682.000000, 11926.000000, 12172.000000, 12420.000000, 12670.000000, 12922.000000, 13176.000000, 13432.000000, 13690.000000, 13950.000000, 14212.000000, 14476.000000, 14742.000000, 15010.000000, 15280.000000, 15552.000000, 15826.000000, 16102.000000, 
16380.000000, 16660.000000, 16942.000000, 17226.000000, 17512.000000, 17800.000000, 18090.000000, 18382.000000, 18676.000000, 18972.000000, 19270.000000, 19570.000000, 19872.000000, 20176.000000, 20482.000000, 20790.000000, 21100.000000, 21412.000000, 21726.000000, 22042.000000, 22360.000000, 22680.000000, 23002.000000, 23326.000000, 23652.000000, 23980.000000, 24310.000000, 24642.000000, 24976.000000, 25312.000000, 25650.000000, 25990.000000, 26332.000000, 26676.000000, 27022.000000, 27370.000000, 27720.000000, 28072.000000, 28426.000000, 28782.000000, 29140.000000, 29500.000000, 29862.000000, 30226.000000, 30592.000000, 30960.000000, 16870.000000, 17182.000000, 17496.000000, 17812.000000, 18130.000000, 18450.000000, 18772.000000, 19096.000000, 19422.000000, 19750.000000, 20080.000000, 20412.000000, 20746.000000, 21082.000000, 21420.000000, 21760.000000, 22102.000000, 22446.000000, 22792.000000, 23140.000000, 23490.000000, 23842.000000, 24196.000000, 24552.000000, 24910.000000, 25270.000000, 25632.000000, 25996.000000, 26362.000000, 26730.000000, 27100.000000, 27472.000000, 27846.000000, 28222.000000, 28600.000000, 28980.000000, 29362.000000, 29746.000000, 30132.000000, 30520.000000, 30910.000000, 31302.000000, 31696.000000, 32092.000000, 32490.000000, 32890.000000, 33292.000000, 33696.000000, 34102.000000, 34510.000000, 34920.000000, 35332.000000, 35746.000000, 36162.000000, 36580.000000, 37000.000000, 37422.000000, 37846.000000, 38272.000000, 38700.000000, 21070.000000, 21442.000000, 21816.000000, 22192.000000, 22570.000000, 22950.000000, 23332.000000, 23716.000000, 24102.000000, 24490.000000, 24880.000000, 25272.000000, 25666.000000, 26062.000000, 26460.000000, 26860.000000, 27262.000000, 27666.000000, 28072.000000, 28480.000000, 28890.000000, 29302.000000, 29716.000000, 30132.000000, 30550.000000, 30970.000000, 31392.000000, 31816.000000, 32242.000000, 32670.000000, 33100.000000, 33532.000000, 33966.000000, 34402.000000, 34840.000000, 35280.000000, 
35722.000000, 36166.000000, 36612.000000, 37060.000000, 37510.000000, 37962.000000, 38416.000000, 38872.000000, 39330.000000, 39790.000000, 40252.000000, 40716.000000, 41182.000000, 41650.000000, 42120.000000, 42592.000000, 43066.000000, 43542.000000, 44020.000000, 44500.000000, 44982.000000, 45466.000000, 45952.000000, 46440.000000 }, sd::DataType::FLOAT32); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z); + // z.printBuffer(); + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_5D_2) { + + auto x = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + auto y = NDArray('f', { 2, 5, 4, 3 }, sd::DataType::FLOAT32); + auto z = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray('c', { 2, 3, 5, 4, 3 }, { 0.100000, 0.181818, 0.250000, 0.307692, 0.357143, 0.400000, 0.437500, 0.470588, 0.500000, 0.526316, 0.550000, 0.571429, 0.590909, 0.608696, 0.625000, 0.640000, 0.653846, 0.666667, 0.678571, 0.689655, 0.700000, 0.709677, 0.718750, 0.727273, 0.735294, 0.742857, 0.750000, 0.756757, 0.763158, 0.769231, 0.775000, 0.780488, 0.785714, 0.790698, 0.795455, 0.800000, 0.804348, 0.808511, 0.812500, 0.816327, 0.820000, 0.823529, 0.826923, 0.830189, 0.833333, 0.836364, 0.839286, 0.842105, 0.844828, 0.847458, 0.850000, 0.852459, 0.854839, 0.857143, 0.859375, 0.861538, 0.863636, 0.865672, 0.867647, 0.869565, 6.100000, 5.636364, 5.250000, 4.923077, 4.642857, 4.400000, 4.187500, 4.000000, 3.833333, 3.684211, 3.550000, 3.428571, 3.318182, 3.217391, 3.125000, 3.040000, 2.961539, 2.888889, 2.821429, 2.758621, 2.700000, 2.645161, 2.593750, 2.545455, 2.500000, 2.457143, 2.416667, 2.378378, 2.342105, 2.307692, 2.275000, 2.243902, 2.214286, 2.186047, 2.159091, 2.133333, 2.108696, 2.085106, 2.062500, 2.040816, 2.020000, 2.000000, 1.980769, 1.962264, 1.944444, 1.927273, 
1.910714, 1.894737, 1.879310, 1.864407, 1.850000, 1.836066, 1.822581, 1.809524, 1.796875, 1.784615, 1.772727, 1.761194, 1.750000, 1.739130, 12.100000, 11.090909, 10.250000, 9.538462, 8.928572, 8.400000, 7.937500, 7.529412, 7.166667, 6.842105, 6.550000, 6.285714, 6.045455, 5.826087, 5.625000, 5.440000, 5.269231, 5.111111, 4.964286, 4.827586, 4.700000, 4.580645, 4.468750, 4.363636, 4.264706, 4.171429, 4.083333, 4.000000, 3.921053, 3.846154, 3.775000, 3.707317, 3.642857, 3.581395, 3.522727, 3.466667, 3.413043, 3.361702, 3.312500, 3.265306, 3.220000, 3.176471, 3.134615, 3.094340, 3.055556, 3.018182, 2.982143, 2.947368, 2.913793, 2.881356, 2.850000, 2.819672, 2.790323, 2.761905, 2.734375, 2.707692, 2.681818, 2.656716, 2.632353, 2.608696, 2.585714, 2.563380, 2.541667, 2.520548, 2.500000, 2.480000, 2.460526, 2.441558, 2.423077, 2.405063, 2.387500, 2.370370, 2.353658, 2.337349, 2.321429, 2.305882, 2.290698, 2.275862, 2.261364, 2.247191, 2.233333, 2.219780, 2.206522, 2.193548, 2.180851, 2.168421, 2.156250, 2.144330, 2.132653, 2.121212, 2.110000, 2.099010, 2.088235, 2.077670, 2.067308, 2.057143, 2.047170, 2.037383, 2.027778, 2.018349, 2.009091, 2.000000, 1.991071, 1.982301, 1.973684, 1.965217, 1.956897, 1.948718, 1.940678, 1.932773, 1.925000, 1.917355, 1.909836, 1.902439, 1.895161, 1.888000, 1.880952, 1.874016, 1.867188, 1.860465, 3.442857, 3.408451, 3.375000, 3.342466, 3.310811, 3.280000, 3.250000, 3.220779, 3.192308, 3.164557, 3.137500, 3.111111, 3.085366, 3.060241, 3.035714, 3.011765, 2.988372, 2.965517, 2.943182, 2.921348, 2.900000, 2.879121, 2.858696, 2.838710, 2.819149, 2.800000, 2.781250, 2.762887, 2.744898, 2.727273, 2.710000, 2.693069, 2.676471, 2.660194, 2.644231, 2.628572, 2.613208, 2.598131, 2.583333, 2.568807, 2.554545, 2.540540, 2.526786, 2.513274, 2.500000, 2.486957, 2.474138, 2.461539, 2.449152, 2.436975, 2.425000, 2.413223, 2.401639, 2.390244, 2.379032, 2.368000, 2.357143, 2.346457, 2.335938, 2.325581, 4.300000, 4.253521, 4.208333, 4.164383, 4.121622, 
4.080000, 4.039474, 4.000000, 3.961539, 3.924051, 3.887500, 3.851852, 3.817073, 3.783133, 3.750000, 3.717647, 3.686047, 3.655172, 3.625000, 3.595506, 3.566667, 3.538461, 3.510870, 3.483871, 3.457447, 3.431579, 3.406250, 3.381443, 3.357143, 3.333333, 3.310000, 3.287129, 3.264706, 3.242718, 3.221154, 3.200000, 3.179245, 3.158879, 3.138889, 3.119266, 3.100000, 3.081081, 3.062500, 3.044248, 3.026316, 3.008696, 2.991379, 2.974359, 2.957627, 2.941176, 2.925000, 2.909091, 2.893443, 2.878049, 2.862903, 2.848000, 2.833333, 2.818898, 2.804688, 2.790698 }, sd::DataType::FLOAT32); + + auto e = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, { 0,2,3,4 }, y, z); + + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_5D_3) { + + auto x = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + auto y = NDArray('f', { 2, 5 }, sd::DataType::FLOAT32); + auto z = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray('c', { 2, 3, 5, 4, 3 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.500000, 0.600000, 0.700000, 0.800000, 0.900000, 1.000000, 1.100000, 1.200000, 1.181818, 1.272727, 1.363636, 1.454545, 1.545455, 1.636364, 1.727273, 1.818182, 1.909091, 2.000000, 2.090909, 2.181818, 2.083333, 2.166667, 2.250000, 2.333333, 2.416667, 2.500000, 2.583333, 2.666667, 2.750000, 2.833333, 2.916667, 3.000000, 2.846154, 2.923077, 3.000000, 3.076923, 3.153846, 3.230769, 3.307692, 3.384615, 3.461539, 3.538461, 3.615385, 3.692308, 3.500000, 3.571429, 3.642857, 3.714286, 3.785714, 3.857143, 3.928571, 4.000000, 4.071429, 4.142857, 4.214286, 4.285714, 6.100000, 6.200000, 6.300000, 6.400000, 6.500000, 6.600000, 6.700000, 6.800000, 6.900000, 7.000000, 7.100000, 7.200000, 6.636364, 6.727273, 6.818182, 6.909091, 7.000000, 7.090909, 7.181818, 
7.272727, 7.363636, 7.454545, 7.545455, 7.636364, 7.083333, 7.166667, 7.250000, 7.333333, 7.416667, 7.500000, 7.583333, 7.666667, 7.750000, 7.833333, 7.916667, 8.000000, 7.461538, 7.538462, 7.615385, 7.692307, 7.769231, 7.846154, 7.923077, 8.000000, 8.076923, 8.153846, 8.230769, 8.307693, 7.785714, 7.857143, 7.928571, 8.000000, 8.071428, 8.142858, 8.214286, 8.285714, 8.357142, 8.428572, 8.500000, 8.571428, 12.100000, 12.200000, 12.300000, 12.400000, 12.500000, 12.600000, 12.700000, 12.800000, 12.900000, 13.000000, 13.100000, 13.200000, 12.090909, 12.181818, 12.272727, 12.363636, 12.454545, 12.545455, 12.636364, 12.727273, 12.818182, 12.909091, 13.000000, 13.090909, 12.083333, 12.166667, 12.250000, 12.333333, 12.416667, 12.500000, 12.583333, 12.666667, 12.750000, 12.833333, 12.916667, 13.000000, 12.076923, 12.153846, 12.230769, 12.307693, 12.384615, 12.461538, 12.538462, 12.615385, 12.692307, 12.769231, 12.846154, 12.923077, 12.071428, 12.142858, 12.214286, 12.285714, 12.357142, 12.428572, 12.500000, 12.571428, 12.642858, 12.714286, 12.785714, 12.857142, 12.066667, 12.133333, 12.200000, 12.266666, 12.333333, 12.400000, 12.466666, 12.533334, 12.600000, 12.666667, 12.733334, 12.800000, 12.062500, 12.125000, 12.187500, 12.250000, 12.312500, 12.375000, 12.437500, 12.500000, 12.562500, 12.625000, 12.687500, 12.750000, 12.058824, 12.117647, 12.176471, 12.235294, 12.294118, 12.352942, 12.411765, 12.470589, 12.529411, 12.588235, 12.647058, 12.705882, 12.055555, 12.111111, 12.166667, 12.222222, 12.277778, 12.333333, 12.388889, 12.444445, 12.500000, 12.555555, 12.611111, 12.666667, 12.052631, 12.105263, 12.157895, 12.210526, 12.263158, 12.315789, 12.368421, 12.421053, 12.473684, 12.526316, 12.578947, 12.631579, 16.066668, 16.133333, 16.200001, 16.266666, 16.333334, 16.400000, 16.466667, 16.533333, 16.600000, 16.666666, 16.733334, 16.799999, 15.812500, 15.875000, 15.937500, 16.000000, 16.062500, 16.125000, 16.187500, 16.250000, 16.312500, 16.375000, 16.437500, 16.500000, 
15.588235, 15.647058, 15.705882, 15.764706, 15.823529, 15.882353, 15.941176, 16.000000, 16.058823, 16.117647, 16.176470, 16.235294, 15.388889, 15.444445, 15.500000, 15.555555, 15.611111, 15.666667, 15.722222, 15.777778, 15.833333, 15.888889, 15.944445, 16.000000, 15.210526, 15.263158, 15.315789, 15.368421, 15.421053, 15.473684, 15.526316, 15.578947, 15.631579, 15.684211, 15.736842, 15.789474, 20.066668, 20.133333, 20.200001, 20.266666, 20.333334, 20.400000, 20.466667, 20.533333, 20.600000, 20.666666, 20.733334, 20.799999, 19.562500, 19.625000, 19.687500, 19.750000, 19.812500, 19.875000, 19.937500, 20.000000, 20.062500, 20.125000, 20.187500, 20.250000, 19.117647, 19.176470, 19.235294, 19.294117, 19.352942, 19.411764, 19.470589, 19.529411, 19.588236, 19.647058, 19.705883, 19.764706, 18.722221, 18.777779, 18.833334, 18.888889, 18.944445, 19.000000, 19.055555, 19.111111, 19.166666, 19.222221, 19.277779, 19.333334, 18.368422, 18.421053, 18.473684, 18.526316, 18.578947, 18.631578, 18.684210, 18.736841, 18.789474, 18.842106, 18.894737, 18.947369 }, sd::DataType::FLOAT32); + + auto e = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyBroadcast(sd::broadcast::Divide, { 0,2 }, y, z); + + ASSERT_EQ(e, z); +} +/////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Test_broadcast_5D_4) { + + auto x = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + auto y = NDArray('f', { 2, 1, 5, 1, 1 }, sd::DataType::FLOAT32); + auto z = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + // recieved by main algorithm + auto eC = NDArray('c', { 2, 3, 5, 4, 3 }, { 0.100000, 0.200000, 0.300000, 0.400000, 0.500000, 0.600000, 0.700000, 0.800000, 0.900000, 1.000000, 1.100000, 1.200000, 1.181818, 1.272727, 1.363636, 1.454545, 1.545455, 1.636364, 1.727273, 1.818182, 1.909091, 2.000000, 2.090909, 2.181818, 2.083333, 2.166667, 2.250000, 2.333333, 
2.416667, 2.500000, 2.583333, 2.666667, 2.750000, 2.833333, 2.916667, 3.000000, 2.846154, 2.923077, 3.000000, 3.076923, 3.153846, 3.230769, 3.307692, 3.384615, 3.461539, 3.538461, 3.615385, 3.692308, 3.500000, 3.571429, 3.642857, 3.714286, 3.785714, 3.857143, 3.928571, 4.000000, 4.071429, 4.142857, 4.214286, 4.285714, 6.100000, 6.200000, 6.300000, 6.400000, 6.500000, 6.600000, 6.700000, 6.800000, 6.900000, 7.000000, 7.100000, 7.200000, 6.636364, 6.727273, 6.818182, 6.909091, 7.000000, 7.090909, 7.181818, 7.272727, 7.363636, 7.454545, 7.545455, 7.636364, 7.083333, 7.166667, 7.250000, 7.333333, 7.416667, 7.500000, 7.583333, 7.666667, 7.750000, 7.833333, 7.916667, 8.000000, 7.461538, 7.538462, 7.615385, 7.692307, 7.769231, 7.846154, 7.923077, 8.000000, 8.076923, 8.153846, 8.230769, 8.307693, 7.785714, 7.857143, 7.928571, 8.000000, 8.071428, 8.142858, 8.214286, 8.285714, 8.357142, 8.428572, 8.500000, 8.571428, 12.100000, 12.200000, 12.300000, 12.400000, 12.500000, 12.600000, 12.700000, 12.800000, 12.900000, 13.000000, 13.100000, 13.200000, 12.090909, 12.181818, 12.272727, 12.363636, 12.454545, 12.545455, 12.636364, 12.727273, 12.818182, 12.909091, 13.000000, 13.090909, 12.083333, 12.166667, 12.250000, 12.333333, 12.416667, 12.500000, 12.583333, 12.666667, 12.750000, 12.833333, 12.916667, 13.000000, 12.076923, 12.153846, 12.230769, 12.307693, 12.384615, 12.461538, 12.538462, 12.615385, 12.692307, 12.769231, 12.846154, 12.923077, 12.071428, 12.142858, 12.214286, 12.285714, 12.357142, 12.428572, 12.500000, 12.571428, 12.642858, 12.714286, 12.785714, 12.857142, 12.066667, 12.133333, 12.200000, 12.266666, 12.333333, 12.400000, 12.466666, 12.533334, 12.600000, 12.666667, 12.733334, 12.800000, 12.062500, 12.125000, 12.187500, 12.250000, 12.312500, 12.375000, 12.437500, 12.500000, 12.562500, 12.625000, 12.687500, 12.750000, 12.058824, 12.117647, 12.176471, 12.235294, 12.294118, 12.352942, 12.411765, 12.470589, 12.529411, 12.588235, 12.647058, 12.705882, 12.055555, 12.111111, 
12.166667, 12.222222, 12.277778, 12.333333, 12.388889, 12.444445, 12.500000, 12.555555, 12.611111, 12.666667, 12.052631, 12.105263, 12.157895, 12.210526, 12.263158, 12.315789, 12.368421, 12.421053, 12.473684, 12.526316, 12.578947, 12.631579, 16.066668, 16.133333, 16.200001, 16.266666, 16.333334, 16.400000, 16.466667, 16.533333, 16.600000, 16.666666, 16.733334, 16.799999, 15.812500, 15.875000, 15.937500, 16.000000, 16.062500, 16.125000, 16.187500, 16.250000, 16.312500, 16.375000, 16.437500, 16.500000, 15.588235, 15.647058, 15.705882, 15.764706, 15.823529, 15.882353, 15.941176, 16.000000, 16.058823, 16.117647, 16.176470, 16.235294, 15.388889, 15.444445, 15.500000, 15.555555, 15.611111, 15.666667, 15.722222, 15.777778, 15.833333, 15.888889, 15.944445, 16.000000, 15.210526, 15.263158, 15.315789, 15.368421, 15.421053, 15.473684, 15.526316, 15.578947, 15.631579, 15.684211, 15.736842, 15.789474, 20.066668, 20.133333, 20.200001, 20.266666, 20.333334, 20.400000, 20.466667, 20.533333, 20.600000, 20.666666, 20.733334, 20.799999, 19.562500, 19.625000, 19.687500, 19.750000, 19.812500, 19.875000, 19.937500, 20.000000, 20.062500, 20.125000, 20.187500, 20.250000, 19.117647, 19.176470, 19.235294, 19.294117, 19.352942, 19.411764, 19.470589, 19.529411, 19.588236, 19.647058, 19.705883, 19.764706, 18.722221, 18.777779, 18.833334, 18.888889, 18.944445, 19.000000, 19.055555, 19.111111, 19.166666, 19.222221, 19.277779, 19.333334, 18.368422, 18.421053, 18.473684, 18.526316, 18.578947, 18.631578, 18.684210, 18.736841, 18.789474, 18.842106, 18.894737, 18.947369 }, sd::DataType::FLOAT32); + + auto e = NDArray('f', { 2, 3, 5, 4, 3 }, sd::DataType::FLOAT32); + e.assign(eC); + + x.linspace(1.f); + y.linspace(10.f); + z.assign(0.f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Divide(), y, z); + + ASSERT_EQ(e, z); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_1) { + + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; + float 
buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; + Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 2, 3, 4, 12, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {0}); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_2) { + + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; + float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; + float expBuff[] = {1,2,3,4, 13, 14, 16, 16, 5,6,7,8, 17, 18, 19, 20, 9, 10, 11, 12, 21, 22, 23, 24}; + Nd4jLong shape1[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 3, 4, 4, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 3, 2, 4, 8, 4, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {1}); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_3) { + + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; + float buff2[] = 
{13,14,16,16,17,18,19,20,21,22,23,24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; + Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 2, 1, 12, 12, 12, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {0}); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_4) { + + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; + float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; + Nd4jLong shape1[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 1, 12, 12, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 1, 2, 12, 24, 12, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {1}); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_5) { + + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; + float buff2[] = 
{13,14,16,16,17,18,19,20,21,22,23,24}; + float expBuff[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; + Nd4jLong shape1[] = {2, 12, 1, 1,1, 0, 1, 99}; + Nd4jLong shape2[] = {2, 12, 1, 1,1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 2, 12, 1, 12, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {0}); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_6) { + + float buff1[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10,11,12}; + float buff2[] = {13,14,16,16,17,18,19,20,21,22,23,24}; + float expBuff[] = {1 ,13 ,2 ,14 ,3 ,16 ,4 ,16 ,5 ,17 ,6 ,18 ,7 ,19 ,8 ,20 ,9 ,21 ,10 ,22 ,11 ,23 ,12 ,24}; + Nd4jLong shape1[] = {2, 12, 1, 1, 12, 0, 1, 99}; + Nd4jLong shape2[] = {2, 12, 1, 1, 12, 0, 1, 99}; + Nd4jLong expShape[] = {3, 12, 2, 1, 2, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(shape2, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray input2(buff2, shape2); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input2}, {}, {1}); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_7) { + + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {2, 
1, 1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 3, 1, 1, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_8) { + + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {0}); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_9) { + + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {2, 1, 1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {3, 1, 3, 1, 3, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_10) 
{ + + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {2, 1, 3, 3, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {1}); + auto output = results.at(0); + + //expected.printShapeInfo("exp"); + //output->printShapeInfo("out"); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +TEST_F(DeclarableOpsTests14, Stack_11) { + + float buff1[] = {1}; + float expBuff[] = {1, 1, 1}; + Nd4jLong shape1[] = {1, 1, 1, 0, 1, 99}; + Nd4jLong expShape[] = {2, 3, 1, 1, 1, 0, 1, 99}; + ArrayOptions::setDataType(shape1, sd::DataType::FLOAT32); + ArrayOptions::setDataType(expShape, sd::DataType::FLOAT32); + + NDArray input1(buff1, shape1); + NDArray expected(expBuff, expShape); + + sd::ops::stack op; + auto results = op.evaluate({&input1, &input1, &input1}, {}, {}); + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_12) { + float inBuff[] = {1.0f, 2.0f, 3.0f}; + float expBuff[] = {1.0f, 2.0f, 3.0f}; + + auto input = NDArrayFactory::create(inBuff, 'c', {1, 3}); + + auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 3}); + + sd::ops::stack op; + + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_13) { + float inBuff[] = {1.0f, 2.0f, 3.0f}; + float expBuff[] = {1.0f, 2.0f, 3.0f}; 
+ + auto input = NDArrayFactory::create(inBuff, 'c', {1, 1, 3}); + + auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 1, 3}); + + sd::ops::stack op; + + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Stack_14) { + float inBuff[] = {1.0f, 2.0f, 3.0f}; + float expBuff[] = {1.0f, 2.0f, 3.0f}; + + auto input = NDArrayFactory::create(inBuff, 'c', {1, 3}); + + auto exp = NDArrayFactory::create(expBuff, 'c', {1, 1, 3}); + + sd::ops::stack op; + + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + //z->printShapeInfo(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, Stack_15) { + auto t = NDArrayFactory::create('c', {2, 3, 5}); + auto u = NDArrayFactory::create('c', {2, 3, 5}); + auto v = NDArrayFactory::create('c', {2, 3, 5}); + auto exp = NDArrayFactory::create('c', {3, 2, 3, 5}); + + sd::ops::stack op; + auto result = op.evaluate({&t, &u, &v}, {}, {-4}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + +} + + +TEST_F(DeclarableOpsTests14, Stack_16) { + auto t = NDArrayFactory::create(1.0f); + auto u = NDArrayFactory::create(2.0f); + auto v = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); + + sd::ops::stack op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, Stack_17) { + auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto v = 
NDArrayFactory::create('c', {1, 1}, {3.0f}); + auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); + auto exp = NDArrayFactory::create('c', {4, 1, 1}, {1, 2, 3, 4}); + + sd::ops::stack op; + auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + // z->printShapeInfo("z shape"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, Stack_18) { + auto x = NDArrayFactory::create('c', {0}); + auto e = NDArrayFactory::create('c', {1, 0}); + + sd::ops::stack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(e, *z); + sd::ops::reduce_min sumOp; + auto res2 = sumOp.evaluate({&e}, {1.}, {1}); + ASSERT_EQ(res2.status(), Status::OK()); + auto out = res2.at(0); + + ASSERT_EQ(out->e(0), DataTypeUtils::infOrMax()); + + +} + +TEST_F(DeclarableOpsTests14, Stack_19) { + auto x = NDArrayFactory::empty(); + auto e = NDArrayFactory::create('c', {0}); + + sd::ops::stack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(e, *z); + + +} + +TEST_F(DeclarableOpsTests14, Stack_20) { + auto x = NDArrayFactory::empty(); + auto e = NDArrayFactory::create('c', {2, 0}); + + sd::ops::stack op; + auto result = op.evaluate({&x, &x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(e, *z); + + +} + +TEST_F(DeclarableOpsTests14, Stack_21) { + + NDArray x1('c', {3,2}, sd::DataType::FLOAT32); + NDArray x2('c', {3,2}, sd::DataType::FLOAT32); + x1.linspace(0); + x2.linspace(6); + + sd::ops::stack opStack; + auto resultStack = opStack.evaluate({&x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, resultStack.status()); + + + sd::ops::concat opConcat; + auto resultConcat = opConcat.evaluate({&x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, resultConcat.status()); + + auto 
outStack = resultStack.at(0); + auto outConcat = resultConcat.at(0); + + outConcat->reshapei({2,3,2}); + + ASSERT_TRUE(outStack->isSameShape(outConcat)); + ASSERT_TRUE(outStack->equalsTo(outConcat)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Reshape1) { + const std::vector xShape = { 5,4,3 }; + const std::vector yShape = { 3,5,4 }; + + auto x = NDArrayFactory::create_('f', xShape); + auto y = NDArrayFactory::create_('f', yShape); + + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + auto block = new Context(1, variableSpace, true); + block->fillInputs({ -1, -2 }); + + sd::ops::reshapeas reshape; + + reshape.execute(block); + + ASSERT_TRUE(x->isSameShape(y)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests14, Reshape2) { + const std::vector xShape = { 5,4,3 }; + const std::vector yShape = { 3,5,4 }; + + auto x = NDArrayFactory::create_('c', xShape); + auto y = NDArrayFactory::create_('c', yShape); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(1, new Variable()); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({ -1 }); + std::vector* arguments = block->getIArguments(); + arguments->push_back(-y->ordering()); + arguments->push_back(3); + arguments->push_back(5); + arguments->push_back(4); + + sd::ops::reshape reshape; + + Nd4jStatus status = reshape.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + + ASSERT_TRUE(result->isSameShape(y)); + + delete y; + delete block; + delete variableSpace; +} + +TEST_F(DeclarableOpsTests14, Flatten2d1) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto zAssertion = NDArrayFactory::create('c', { 3, 20 }); + + 
sd::ops::flatten_2d op; + auto result = op.evaluate({ &x }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(result.at(0)->isSameShape(zAssertion)); +} + +TEST_F(DeclarableOpsTests14, Flatten2d2) { + auto x = NDArrayFactory::create('c', { 2,3, 4, 5 }); + auto zAssertion = NDArrayFactory::create('c', { 6, 20 }); + + sd::ops::flatten_2d op; + auto result = op.evaluate({ &x }, {}, { -2 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(result.at(0)->isSameShape(zAssertion)); +} + +TEST_F(DeclarableOpsTests14, Reshape3) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { -99, 3, 4, 5 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x.isSameShape(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape4) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { 3, 4, 5 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x.isSameShape(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape5) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { 5, 4, 3 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); +} + +TEST_F(DeclarableOpsTests14, Reshape6) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto exp = NDArrayFactory::create('c', { 4, 15 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { 4, -1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(exp)); +} + +TEST_F(DeclarableOpsTests14, Reshape7) { + auto x = NDArrayFactory::create('c', { 3, 4, 5 }); + auto exp = NDArrayFactory::create('c', { 60 }); + + sd::ops::reshape op; + auto result = op.evaluate({ &x }, {}, { -1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); 
+ + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(exp)); +} + +TEST_F(DeclarableOpsTests14, Reshape8) { + auto x = NDArrayFactory::create('f', {2, 3}, {1.0, 4.0, 2.0, 5.0, 3.0, 6.0}); + auto e = NDArrayFactory::create('f', {3, 2}, {1.0, 3.0, 5.0, 2.0, 4.0, 6.0}); + + auto r = x.reshape('c', {3, 2});; + r.streamline('f'); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {3, 2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +} + +TEST_F(DeclarableOpsTests14, Reshape9) { + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + + sd::ops::reshape op; + auto result = op.evaluate({&array}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); +} + +TEST_F(DeclarableOpsTests14, Reshape10) { + auto array = NDArrayFactory::create(119.f); + auto e = NDArrayFactory::create('c', {1, 1}, {119.f}); + auto z = NDArrayFactory::create('c', {1, 1}); + + sd::ops::reshape op; + auto result = op.execute({&array}, {&z}, {}, {1, 1}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests14, Reshape11) { + auto x = NDArrayFactory::create('c', {4, 3}); + auto exp = NDArrayFactory::create('c', {4, 3}); + + x.linspace(1); + exp.linspace(1); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {-99, 4, 3}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape12) { + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + auto shape = NDArrayFactory::create('c', {2}, {-1, 2}); + auto exp = NDArrayFactory::create('c', {4, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + + sd::ops::reshape op; + auto result = op.evaluate({&x, &shape}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape13) { + 
auto vector = NDArrayFactory::create('c', {1}, {119.0f}); + auto exp = NDArrayFactory::create(119.f); + auto empty = NDArrayFactory::empty_(); + + sd::ops::reshape op; + auto result = op.evaluate({&vector, empty}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(exp, *result.at(0)); + + delete empty; +} + +TEST_F(DeclarableOpsTests14, Reshape14) { + auto x = NDArrayFactory::create('c', {1, 0, 0, 2}); + auto y = NDArrayFactory::create('c', {2}, {10, 0}); + auto e = NDArrayFactory::create('c', {10, 0}); + + sd::ops::reshape op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_EQ(e, *z); + +} + +TEST_F(DeclarableOpsTests14, Reshape15) { + + auto x0 = NDArrayFactory::create('c', {2, 0}); + auto x1 = NDArrayFactory::create('c', {0, 1, 2}); + + auto shape0 = NDArrayFactory::create('c', {3}, {2, 0, -1}); + auto shape1 = NDArrayFactory::create('c', {2}, {-1, 1}); + + auto e0 = NDArrayFactory::create('c', {2, 0, 1}); + auto e1 = NDArrayFactory::create('c', {0, 1}); + + sd::ops::reshape op; + auto result0 = op.evaluate({&x0, &shape0}, {}, {}); + ASSERT_EQ(Status::OK(), result0.status()); + auto z0 = result0.at(0); + ASSERT_EQ(e0, *z0); + + auto result1 = op.evaluate({&x1, &shape1}, {}, {}); + ASSERT_EQ(Status::OK(), result1.status()); + auto z1 = result1.at(0); + ASSERT_EQ(e1, *z1); +} + +TEST_F(DeclarableOpsTests14, Reshape16) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto shape = NDArrayFactory::create('c', {1, 3}, {1, 2, 2}); + + auto exp = NDArrayFactory::create('c', {1, 2, 2}, {1, 2, 3, 4}); + + sd::ops::reshape op; + + auto result = op.evaluate({&x, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape17) { + auto x = NDArrayFactory::create(2.0f); + auto exp = 
NDArrayFactory::create('c', {1, 1, 1}, {2.0f}); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 1, 1, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(DeclarableOpsTests14, Reshape18) { + auto x = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests14, Reshape19) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + + sd::ops::reshape op; + auto result = op.evaluate({&x}, {}, {-99, 1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + + +TEST_F(DeclarableOpsTests14, Reshape20) { + + NDArray x1('c', {2,0}, sd::DataType::FLOAT32); + NDArray x2('c', {10,0}, sd::DataType::FLOAT32); + NDArray x3('c', {2,0,0,10}, sd::DataType::FLOAT32); + NDArray x4('c', {0,0,10}, sd::DataType::FLOAT32); + NDArray x5('c', {0,2,10}, sd::DataType::FLOAT32); + NDArray x6('c', {0,10,0}, sd::DataType::FLOAT32); + NDArray x7('c', {0,1,2}, sd::DataType::FLOAT32); + NDArray x8('c', {1,2,0}, sd::DataType::FLOAT32); + + sd::ops::reshape op; + + auto result = op.evaluate({&x1}, {}, {2, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0})); + + result = op.evaluate({&x2}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,5})); + + result = op.evaluate({&x2}, {}, {5, 2, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({5,2,0})); + + result = op.evaluate({&x2}, 
{}, {-1, 2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({5,2,0})); + + result = op.evaluate({&x3}, {}, {2, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,10})); + + result = op.evaluate({&x4}, {}, {2, -1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,5,0})); + + result = op.evaluate({&x5}, {}, {2, 0, 0, 0, -1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,0,0,10})); + + result = op.evaluate({&x6}, {}, {-1, 2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({5, 2, 0})); + + result = op.evaluate({&x7}, {}, {-1, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2, 0})); + + result = op.evaluate({&x7}, {}, {10,0,50,100}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({10,0,50,100})); + + result = op.evaluate({&x7}, {}, {2,0,-1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(result.at(0)->isSameShape({2,0,1})); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests15.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests15.cpp new file mode 100644 index 000000000..779d0ccf8 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests15.cpp @@ -0,0 +1,2027 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// Created by raver on 8/4/2018. +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace sd; + + +class DeclarableOpsTests15 : public testing::Test { +public: + + DeclarableOpsTests15() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests15, Test_NormalizeMoments_1) { + auto d = NDArrayFactory::create('c', {10, 10}); + auto w = NDArrayFactory::create(10); + auto x = NDArrayFactory::create('c', {10}); + auto y = NDArrayFactory::create('c', {10}); + + auto z0 = NDArrayFactory::create('c', {10}); + auto z1 = NDArrayFactory::create('c', {10}); + + sd::ops::normalize_moments op; + auto result = op.execute({&w, &x, &y}, std::vector{&z0, &z1}, {1e-4}, {}, {}); + ASSERT_EQ(Status::OK(), result); +} + +TEST_F(DeclarableOpsTests15, Test_Add_1) { + auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto y = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto e = NDArrayFactory::create('c', {5}, {2, 2, 2, 2, 2}); + + sd::ops::add op; + auto result = op.execute({&x, &y}, {&x}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, x); +} + +TEST_F(DeclarableOpsTests15, Test_Half_assign_1) { + auto x = NDArrayFactory::create('c', {2, 5}); + int y = 1; + x.assign(y); + + ASSERT_EQ(10, x.sumNumber().e(0)); +} + +TEST_F(DeclarableOpsTests15, Test_standarize_1) { + auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); + + sd::ops::standardize op; + auto 
result = op.execute({&x}, {&x}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, x); +} + +TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) { + auto x = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto eps = NDArrayFactory::create('c', {5}, {0.f, 0.f, 0.f, 0.f, 0.f}); + + sd::ops::standardize_bp op; + auto result = op.evaluate({&x, &eps}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + +} + +TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) { + auto x = NDArrayFactory::create('c', {4,4,3}); + NDArray factor = NDArrayFactory::create(2.); + auto e = NDArrayFactory::create('c', {4,4,3}, {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, + 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, + 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, + 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5}); + + + x.linspace(1.); + sd::ops::adjust_contrast op; + auto result = op.evaluate({&x, &factor}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + + ASSERT_TRUE(e.equalsTo(out)); + +} + +TEST_F(DeclarableOpsTests15, Test_AdjustContrast_2) { + auto x = NDArrayFactory::create('c', {1, 4,4,3}); + auto e = NDArrayFactory::create('c', {1, 4,4,3}, { + -21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, -7.5f, -3.5f, -2.5f, -1.5f, + 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f, + 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 44.5f, 45.5f, 46.5f, + 50.5f, 51.5f, 52.5f, 56.5f, 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f + }); + x.linspace(1.); + sd::ops::adjust_contrast op; + auto result = op.evaluate({&x}, {2.}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); +// out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); + +} + +TEST_F(DeclarableOpsTests15, Test_AdjustContrast_3) { + auto x = 
NDArrayFactory::create('c', {1, 4,4,3}); + auto e = NDArrayFactory::create('c', {1, 4,4,3}, { + -21.5f, -20.5f, -19.5f, -15.5f, -14.5f, -13.5f, -9.5f, -8.5f, -7.5f, -3.5f, -2.5f, -1.5f, + 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f, 20.5f, 21.5f, 22.5f, + 26.5f, 27.5f, 28.5f, 32.5f, 33.5f, 34.5f, 38.5f, 39.5f, 40.5f, 44.5f, 45.5f, 46.5f, + 50.5f, 51.5f, 52.5f, 56.5f, 57.5f, 58.5f, 62.5f, 63.5f, 64.5f, 68.5f, 69.5f, 70.5f + }); + x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); +// out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); + +} + +TEST_F(DeclarableOpsTests15, Test_AdjustContrast_4) { + auto x = NDArrayFactory::create('c', {4, 4, 3}); + auto e = NDArrayFactory::create('c', {4, 4, 3}, { + -21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5, + 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5, + 26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5, + 50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5 + }); + x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); +// out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); + +} + +TEST_F(DeclarableOpsTests15, Test_AdjustContrast_5) { + auto x = NDArrayFactory::create('c', {1, 3, 4}); + auto e = NDArrayFactory::create('c', {1, 3, 4}, { + -3., -2., -1., 0., 5., 6., 7., 8., 13., 14., 15., 16. 
+ }); + x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); +// out->printIndexedBuffer("Adjusted Constrast"); + ASSERT_TRUE(e.equalsTo(out)); + +} + +/* + * public void testAdjustContrast1() { + INDArray in = Nd4j.createFromArray(new float[]{0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, + 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, + 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, + 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, + 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, + 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, + 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, + 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, + 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, + 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, + 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, + 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, + 0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f, + 0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f, + 0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f, + .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, + 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f + }).reshape(8,8,3,1); + INDArray out = 
Nd4j.create(DataType.FLOAT, in.shape()); + INDArray[] res = Nd4j.exec(new AdjustContrast(in, 2.0, out)); + assertArrayEquals(out.shape(), in.shape()); + //assertEquals(expected, out); + } + * */ + +TEST_F(DeclarableOpsTests15, Test_AdjustContrast_6) { + auto x = NDArrayFactory::create('c', {8,8, 3, 1}, {0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, + 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, + 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, + 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, + 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, + 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, + 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, + 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, + 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, + 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, + 0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, + 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, + 0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f, + 0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f, + 0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f, + .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, + 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); + auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { + 1.0218375f, 1.0666375f, 0.9130375f, 
+ -0.07396251f, 0.91843754f, -0.17496246f, + 0.47543746f, 1.2492375f, 0.55643755f, + 1.3110375f, -0.36456245f, 1.0518374f, + 0.7824375f, 0.57523745f, -0.21656245f, + 0.0816375f, -0.2261625f, 0.40323752f, + 1.4520376f, 0.6868375f, 0.81723756f, + -0.17576247f, 0.81423753f, -0.08656245f, + + -0.36249164f, 0.45590833f, 1.1925083f, + 0.00650835f, 1.4861084f, 1.2079083f, + 0.05270836f, 0.37350836f, 0.94130826f, + 1.0715083f, 0.6103083f, 0.9825083f, + 0.07370833f, -0.4518917f, -0.39889166f, + -0.3354917f, 1.2213084f, 1.0345083f, + -0.3132917f, 0.78470826f, 0.23390833f, + 0.6943083f, 0.68170834f, -0.09989169f, + + 0.8352709f, 1.3798709f, 0.15507084f, + 0.26607084f, -0.10792917f, 1.2302709f, + 0.6448709f, -0.29992914f, 1.3534708f, + 0.86607087f, 0.37607086f, 0.04027084f, + 0.40087086f, 0.59507084f, 0.9416709f, + 0.53127086f, -0.01712915f, 1.4610709f, + -0.17152917f, -0.13992918f, 0.6242708f, + -0.42192918f, 0.38387084f, -0.15752912f, + + 0.3311833f, 0.00618333f, 0.17538333f, + 0.10418332f, 0.8365834f, 0.27098334f, + 1.2421833f, -0.1114167f, 1.0153834f, + 0.9523833f, 0.8317833f, 0.9633833f, + 0.6501833f, 0.04258335f, 0.9999833f, + -0.40181667f, 0.11418331f, 0.47938335f, + 1.1057833f, -0.29761666f, 1.0779834f, + 0.5243833f, -0.32181668f, 1.1833833f, + + 0.73157084f, 0.4317708f, 0.7283708f, + 1.2297708f, 0.4307708f, 0.85377085f, + 0.05977082f, -0.09282917f, 0.33957082f, + 1.0751709f, 0.2119708f, 0.51897085f, + -0.25302917f, 1.1723708f, -0.12562919f, + 1.1993709f, 0.5257708f, 0.40517086f, + 0.53197086f, 0.8441708f, 0.02617085f, + -0.0208292f, 0.8711709f, 0.04137081f, + + 0.74936247f, 0.6085625f, 0.8997625f, + -0.08743751f, 0.18576252f, -0.17563748f, + 0.5991625f, -0.0038375f, 0.07576251f, + 0.42536253f, -0.22823751f, 0.36296248f, + 0.81456256f, -0.16183749f, 0.5161625f, + -0.21183747f, 0.7429625f, 0.6217625f, + 0.17656249f, 0.02616251f, -0.17923748f, + 1.4659625f, 0.40016252f, 0.28356248f, + + 0.4195791f, 0.8745791f, 0.36637908f, + 0.50597906f, -0.17942089f, 0.16917908f, + 
1.0235791f, 1.3699791f, -0.11382091f, + -0.0918209f, 0.7757791f, 0.09017909f, + 1.3807791f, -0.15202093f, 1.3875791f, + -0.1712209f, 1.3989791f, 0.43777913f, + 0.7855791f, 0.1423791f, 1.4711791f, + 0.6455791f, 0.6211791f, -0.48062086f, + + 0.10189578f, 0.5628958f, 0.68909574f, + 0.96649575f, -0.09370419f, 1.3466958f, + 1.4584957f, 1.3544958f, -0.3829042f, + 0.11269578f, -0.47890422f, 1.0436958f, + 0.6128957f, 0.27209583f, 0.2714958f, + 0.21889582f, 0.08789578f, 1.1296958f, + 0.4596958f, 0.39309582f, 0.8344958f, + 0.71149576f, -0.4799042f, 0.4880958f + }); + + sd::ops::adjust_contrast op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); +// out->printBuffer("Adjusted Constrast6"); +// e.printBuffer("Adjusted Expected 6"); +// ASSERT_TRUE(e.equalsTo(out)); + +} + +TEST_F(DeclarableOpsTests15, Test_AdjustContrast_7) { + auto x = NDArrayFactory::create('c', {8,8, 3, 1}, {0.7788f,0.8012f,0.7244f,0.2309f,0.7271f,0.1804f, + 0.5056f,0.8925f,0.5461f,0.9234f,0.0856f,0.7938f,0.6591f,0.5555f,0.1596f,0.3087f,0.1548f,0.4695f, + 0.9939f,0.6113f,0.6765f,0.1800f,0.6750f,0.2246f,0.0509f,0.4601f,0.8284f,0.2354f,0.9752f,0.8361f, + 0.2585f,0.4189f,0.7028f,0.7679f,0.5373f,0.7234f,0.2690f,0.0062f,0.0327f,0.0644f,0.8428f,0.7494f, + 0.0755f,0.6245f,0.3491f,0.5793f,0.5730f,0.1822f,0.6420f,0.9143f,0.3019f, + 0.3574f,0.1704f,0.8395f,0.5468f,0.0744f,0.9011f,0.6574f,0.4124f,0.2445f,0.4248f,0.5219f, + 0.6952f,0.4900f,0.2158f,0.9549f,0.1386f,0.1544f,0.5365f,0.0134f,0.4163f,0.1456f,0.4109f, + 0.2484f, 0.3330f,0.2974f,0.6636f,0.3808f,0.8664f, 0.1896f, 0.7530f, 0.7215f, 0.6612f, 0.7270f, + 0.5704f,0.2666f,0.7453f,0.0444f,0.3024f,0.4850f,0.7982f,0.0965f,0.7843f,0.5075f, + 0.0844f,0.8370f,0.6103f,0.4604f,0.6087f, 0.8594f, 0.4599f, 0.6714f, 0.2744f, 0.1981f, 0.4143f, + 0.7821f,0.3505f,0.5040f,0.1180f,0.8307f,0.1817f,0.8442f,0.5074f,0.4471f,0.5105f,0.6666f, + 
0.2576f,0.2341f,0.6801f,0.2652f,0.5394f,0.4690f,0.6146f,0.1210f,0.2576f,0.0769f,0.4643f, + 0.1628f,0.2026f,0.3774f,0.0506f,0.3462f,0.5720f,0.0838f,0.4228f,0.0588f,0.5362f,0.4756f, + 0.2530f,0.1778f,0.0751f,0.8977f,0.3648f,0.3065f,0.4739f,0.7014f,0.4473f,0.5171f,0.1744f, + 0.3487f,0.7759f,0.9491f,0.2072f,0.2182f,0.6520f,0.3092f,0.9545f,0.1881f,0.9579f,0.1785f, + 0.9636f,0.4830f,0.6569f,0.3353f,0.9997f,0.5869f,0.5747f,0.0238f,0.2943f,0.5248f,0.5879f, + .7266f,0.1965f,0.9167f,0.9726f,0.9206f,0.0519f,0.2997f,0.0039f,0.7652f,0.5498f, + 0.3794f,0.3791f,0.3528f,0.2873f,0.8082f,0.4732f,0.4399f,0.6606f,0.5991f,0.0034f,0.4874f}); + auto e = NDArrayFactory::create('c', {8, 8, 3, 1}, { + 1.0218375, 1.0666375 , 0.9130375 , + -0.07396251, 0.91843754, -0.17496246, + 0.47543746, 1.2492375 , 0.55643755, + 1.3110375 , -0.36456245, 1.0518374 , + 0.7824375 , 0.57523745, -0.21656245, + 0.0816375 , -0.2261625 , 0.40323752, + 1.4520376 , 0.6868375 , 0.81723756, + -0.17576247, 0.81423753, -0.08656245, + + -0.36249164, 0.45590833, 1.1925083 , + 0.00650835, 1.4861084 , 1.2079083 , + 0.05270836, 0.37350836, 0.94130826, + 1.0715083 , 0.6103083 , 0.9825083 , + 0.07370833, -0.4518917 , -0.39889166, + -0.3354917 , 1.2213084 , 1.0345083 , + -0.3132917 , 0.78470826, 0.23390833, + 0.6943083 , 0.68170834, -0.09989169, + + 0.8352709 , 1.3798709 , 0.15507084, + 0.26607084, -0.10792917, 1.2302709 , + 0.6448709 , -0.29992914, 1.3534708 , + 0.86607087, 0.37607086, 0.04027084, + 0.40087086, 0.59507084, 0.9416709 , + 0.53127086, -0.01712915, 1.4610709 , + -0.17152917, -0.13992918, 0.6242708 , + -0.42192918, 0.38387084, -0.15752912, + + + 0.3311833 , 0.00618333, 0.17538333, + 0.10418332, 0.8365834 , 0.27098334, + 1.2421833 , -0.1114167 , 1.0153834 , + 0.9523833 , 0.8317833 , 0.9633833 , + 0.6501833 , 0.04258335, 0.9999833 , + -0.40181667, 0.11418331, 0.47938335, + 1.1057833 , -0.29761666, 1.0779834 , + 0.5243833 , -0.32181668, 1.1833833 , + + 0.73157084, 0.4317708 , 0.7283708 , + 1.2297708 , 0.4307708 , 
0.85377085, + 0.05977082, -0.09282917, 0.33957082, + 1.0751709 , 0.2119708 , 0.51897085, + -0.25302917, 1.1723708 , -0.12562919, + 1.1993709 , 0.5257708 , 0.40517086, + 0.53197086, 0.8441708 , 0.02617085, + -0.0208292 , 0.8711709 , 0.04137081, + + 0.74936247, 0.6085625 , 0.8997625 , + -0.08743751, 0.18576252, -0.17563748, + 0.5991625 , -0.0038375 , 0.07576251, + 0.42536253, -0.22823751, 0.36296248, + 0.81456256, -0.16183749, 0.5161625 , + -0.21183747, 0.7429625 , 0.6217625 , + 0.17656249, 0.02616251, -0.17923748, + 1.4659625 , 0.40016252, 0.28356248, + + 0.4195791 , 0.8745791 , 0.36637908, + 0.50597906, -0.17942089, 0.16917908, + 1.0235791 , 1.3699791 , -0.11382091, + -0.0918209 , 0.7757791 , 0.09017909, + 1.3807791 , -0.15202093, 1.3875791 , + -0.1712209 , 1.3989791 , 0.43777913, + 0.7855791 , 0.1423791 , 1.4711791 , + 0.6455791 , 0.6211791 , -0.48062086, + + + 0.10189578, 0.5628958 , 0.68909574, + 0.96649575, -0.09370419, 1.3466958 , + 1.4584957 , 1.3544958 , -0.3829042 , + 0.11269578, -0.47890422, 1.0436958 , + 0.6128957 , 0.27209583, 0.2714958 , + 0.21889582, 0.08789578, 1.1296958 , + 0.4596958 , 0.39309582, 0.8344958 , + 0.71149576, -0.4799042, 0.4880958 + }); +// x.linspace(1.); + sd::ops::adjust_contrast_v2 op; + auto result = op.evaluate({&x}, {2.}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); +// out->printBuffer("Adjusted Constrast7"); +// e.printBuffer("Adjusted expected 7"); + auto diff = e - *out; +// diff.printBuffer("Adjusted subtract 7"); + ASSERT_TRUE(e.equalsTo(out)); + +} + +TEST_F(DeclarableOpsTests15, Test_BitCast_1) { + auto x = NDArrayFactory::create('c', {2, 2, 2}); + auto e = NDArrayFactory::create('c', {2, 2}, {2., 512., 8192., 131072.032 }); + x.linspace(1.); + + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {(int) sd::DataType::DOUBLE}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); +// out->printIndexedBuffer("Casted result"); + ASSERT_TRUE(e.equalsTo(out)); + +} + 
+TEST_F(DeclarableOpsTests15, Test_BitCast_2) { + auto x = NDArrayFactory::create('c', {2, 4}); + auto e = NDArrayFactory::create('c', {2, 4, 2}, {0.f, 1.875f, 0.f, 2.f, 0.f, 2.125f, 0.f, 2.25f, + 0.f, 2.312f, 0.f, 2.375f, 0.f, 2.438f, 0.f, 2.5f}); + x.linspace(1.); + + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {(int) sd::DataType::HALF}); + ASSERT_EQ(Status::OK(), result.status()); + auto out = result.at(0); + + ASSERT_TRUE(e.equalsTo(out)); + +} + +TEST_F(DeclarableOpsTests15, Test_BitCast_3) { + auto x = NDArrayFactory::create('c', {1, 4}); + + x.linspace(1.); + sd::ops::bitcast op; + try { + auto result = op.evaluate({&x}, {(int) sd::DataType::INT64}); + ASSERT_NE(Status::OK(), result.status()); + + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } +} + +TEST_F(DeclarableOpsTests15, Test_BitCast_4) { + auto x = NDArrayFactory::create('c', {1, 4}); + auto e = NDArrayFactory::create('c', {1, 2}, {1234567890LL, 2468013579LL}); + x.linspace(1.); + sd::ops::bitcast op; + try { + auto result = op.execute({&x}, {&e}, {}, {sd::DataType::INT64}, {}); + ASSERT_NE(Status::OK(), result); + } catch(std::exception& e) { + nd4j_printf("Error `%s' should be here. 
It's OK.\n",e.what()); + } + +} + +TEST_F(DeclarableOpsTests15, Test_BitCast_4_1) { + auto x = NDArrayFactory::create('c', {1, 2}); + auto e = NDArrayFactory::create('c', {1, 2}, {4607182418800017408LL, 4611686018427387904LL}); // as TF 4607182418800017408, 4611686018427387904 + x.linspace(1.); + sd::ops::bitcast op; + + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + // e.printIndexedBuffer("Double to int64"); + auto res = result.at(0); + ASSERT_EQ(*res, e); + +} + + +TEST_F(DeclarableOpsTests15, Test_BitCast_5) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 0.4922f, 0.2969f, 0.6172f, 0.8906f, + 0.9297f, 0.0859f, 0.2344f, 0.3828f, + 0.5781f, 0.7969f, 0.0391f, 0.1719f, + 0.8359f, 0.9297f, 0.3438f, 0.0938f}); + + auto e = NDArrayFactory::create('c', {4}, {4260467851820808160LL, 3900173902914993008LL, 3566895990128523424LL, + 3314989625590692528LL}); + + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + +// res->printIndexedBuffer("BITCAST5"); + ASSERT_TRUE(e.equalsTo(res)); + +} + +TEST_F(DeclarableOpsTests15, Test_BitCast_6) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, + 13.f, 14.f, 15.f, 16.f}); + + auto e = NDArrayFactory::create('c', {4}, {4899988963420290048LL, 5188224837230806272LL, 5332342774136064128LL, + 5476460161268730496LL}); + + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + +// res->printIndexedBuffer("BITCAST6"); + ASSERT_TRUE(e.equalsTo(res)); + +} +TEST_F(DeclarableOpsTests15, Test_BitCast_7) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 1.1f, 2.2f, 3.3f, 4.4f, + 5.1f, 6.2f, 7.3f, 8.4f, + 9.1f, 10.2f, 11.3f, 12.4f, + 13.f, 14.2f, 15.3f, 16.4f}); + + auto e = 
NDArrayFactory::create('c', {4}, { + 4928700072476425318LL, 5202580391758873882LL, 5346698272827918477LL, 5483778673873668736LL}); + + sd::ops::bitcast op; + auto result = op.evaluate({&x}, {}, {sd::DataType::INT64}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto res = result.at(0); + +// res->printIndexedBuffer("BITCAST7"); + ASSERT_TRUE(e.equalsTo(res)); + +} + +TEST_F(DeclarableOpsTests15, test_matmul_bp_1) { + auto a = NDArrayFactory::create('c', {1, 3}); + auto b = NDArrayFactory::create('c', {1, 4}); + auto gI = NDArrayFactory::create('c', {3, 4}); + + auto gA = NDArrayFactory::create('c', {1, 3}); + auto gB = NDArrayFactory::create('c', {1, 4}); + + sd::ops::matmul_bp op; + auto status = op.execute({&a, &b, &gI}, std::vector{&gA, &gB}, {}, {1, 0, 0}, {}); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests15, test_non_decreasing_1) { + auto x = NDArrayFactory::create(1.0); + auto z = NDArrayFactory::create(false); + auto e = NDArrayFactory::create(true); + + sd::ops::is_non_decreasing op; + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setOutputArray(0, &z); + + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests15, test_check_numeric_1) { + auto x = NDArrayFactory::create('c', {3},{1.f, 2.f, 3.f}); + auto y = NDArrayFactory::string("shouldn't ever trigger"); + + sd::ops::check_numerics op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(x, *z); +} + +TEST_F(DeclarableOpsTests15, test_check_numeric_2) { +#ifdef FFAST_MATH + if (1 > 0) + return; +#endif + + auto x = NDArrayFactory::create('c', {3},{1.f, 2.f, std::numeric_limits::infinity()}); + auto y = NDArrayFactory::string("should trigger"); + auto z = NDArrayFactory::create('c', {3} ); + + sd::ops::check_numerics op; + try { + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + ASSERT_TRUE(false); + } catch 
(std::invalid_argument &e) { + // + } +} + +TEST_F(DeclarableOpsTests15, test_check_numeric_3) { +#ifdef FFAST_MATH + if (1 > 0) + return; +#endif + + auto x = NDArrayFactory::create('c', {3},{1.f, 2.f, std::numeric_limits::quiet_NaN()}); + auto y = NDArrayFactory::string("should trigger"); + auto z = NDArrayFactory::create('c', {3} ); + + sd::ops::check_numerics op; + try { + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + ASSERT_TRUE(false); + } catch (std::invalid_argument &e) { + // + } +} + +TEST_F(DeclarableOpsTests15, Test_layer_norm_1) { + auto x = NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto g = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + + sd::ops::layer_norm op; + auto result = op.evaluate({&x, &g, &b}, {}, {0}, {false}); + ASSERT_EQ(Status::OK(), result.status()); + +} + +TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) { + auto x = NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto g = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto b = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto eps = NDArrayFactory::create('c', {1, 5}, {0.f, 0.f, 0.f, 0.f, 0.f}); + + sd::ops::layer_norm_bp op; + auto result = op.evaluate({&x, &g, &b, &eps}, {}, {0}, {false}); + ASSERT_EQ(Status::OK(), result.status()); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_2) { + + NDArray x('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); + NDArray gain('c', {4}, {-0.1, 0.1, -0.2, 0.2}, sd::DataType::FLOAT32); + NDArray bias('c', {4}, {-0.05, 0.05, -1.05, 1.05}, sd::DataType::FLOAT32); + NDArray gradO('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); + + NDArray gradI('c', {3, 4, 8, 8}, sd::DataType::FLOAT32); + NDArray gradG('c', {4}, sd::DataType::FLOAT32); + NDArray gradB('c', {4}, sd::DataType::FLOAT32); + + x.linspace(-20, 0.5); + 
gradO.linspace(-4, 0.05); + + sd::ops::layer_norm_bp op; + auto status = op.execute({&x, &gain, &bias, &gradO}, {&gradI, &gradG, &gradB}, {}, {1,2,3}, {true}); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests15, test_hashCode_1) { + auto x = NDArrayFactory::create('c', {10}); + auto y = NDArrayFactory::create('c', {10}); + + x.linspace(1.); + y.linspace(2.); + + sd::ops::hashcode op; + auto resultA0 = op.evaluate({&x}); + auto resultA1 = op.evaluate({&x}); + auto resultB0 = op.evaluate({&y}); +// resultA0->at(0)->printIndexedBuffer("A0"); +// resultA1->at(0)->printIndexedBuffer("A1"); +// resultB0->at(0)->printIndexedBuffer("B0"); + ASSERT_EQ(*resultA0.at(0), *resultA1.at(0)); + ASSERT_NE(*resultA0.at(0), *resultB0.at(0)); +} + +TEST_F(DeclarableOpsTests15, test_hashCode_2) { + auto x = NDArrayFactory::create('c', {1027}); + auto y = NDArrayFactory::create('c', {1027}); + + x.linspace(1.); + y.linspace(2.); + + sd::ops::hashcode op; + auto resultA0 = op.evaluate({&x}); + auto resultA1 = op.evaluate({&x}); + auto resultB0 = op.evaluate({&y}); + +// resultA0->at(0)->printIndexedBuffer("A0"); +// resultA1->at(0)->printIndexedBuffer("A1"); +// resultB0->at(0)->printIndexedBuffer("B0"); + + ASSERT_EQ(*resultA0.at(0), *resultA1.at(0)); + ASSERT_NE(*resultA0.at(0), *resultB0.at(0)); +} + +TEST_F(DeclarableOpsTests15, test_rank_1) { + auto array = NDArrayFactory::create('c', {4, 64}); + auto e = NDArrayFactory::create('c', {}, {2}); + auto z = NDArrayFactory::create('c', {}); + + sd::ops::rank op; + auto result = op.execute({&array}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests15, test_rank_2) { + auto array = NDArrayFactory::create('c', {4, 64}); + auto e = NDArrayFactory::create('c', {}, {2}); + + sd::ops::rank op; + auto result = op.evaluate({&array}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + + +} + +TEST_F(DeclarableOpsTests15, 
test_lstmBlock_1) { + auto x0 = NDArrayFactory::create(5); + auto x1 = NDArrayFactory::create('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f}); + auto x2 = NDArrayFactory::create('c', {1, 3}, {0.7717289f, 0.9280778f, 0.98455656f}); + auto x3 = NDArrayFactory::create('c', {1, 3}, {0.94414854f, 0.5956861f, 0.8668989f}); + auto x4 = NDArrayFactory::create('c', {7, 12}, {0.460692f, 0.042572856f, 0.08420354f, -0.09538093f, -0.11416581f, -0.53166187f, 0.40133476f, -0.24381405f, 0.30778718f, 0.52713746f, 0.16253126f, -0.034891903f, 0.011679292f, -0.19076681f, 0.14710993f, -0.3704369f, 0.51872355f, 0.13536876f, -0.5568739f, -0.08727971f, 0.07601875f, -0.074174374f, -0.5345982f, -0.3581748f, -0.28263924f, -0.25141674f, 0.43328637f, -0.50227314f, -0.26641843f, -0.38241976f, -0.19636461f, -0.04020852f, -0.27312332f, 0.5207915f, -0.37247592f, -0.4713087f, -0.25670746f, -0.14942765f, -0.015806139f, -0.22531253f, 0.5582536f, 0.3093416f, 0.3221351f, -0.0964683f, 0.14318448f, 0.42279094f, -0.46992f, -0.43399644f, -0.51704615f, -0.11854091f, 0.21697259f, -0.049382925f, 0.14059627f, 0.3912331f, -0.41345632f, 0.5067368f, -0.3420229f, 0.485789f, 0.044918716f, 0.26209074f, 0.12357575f, 0.21778125f, -0.53791714f, 0.18346387f, 0.054183125f, 0.5480431f, 0.03675288f, -0.26656917f, -0.018610716f, 0.19917983f, 0.5566165f, 0.43570566f, -0.35720813f, 0.31097364f, -0.47134516f, -0.289197f, 0.091138184f, 0.13300979f, -0.36592877f, -0.17540845f, 0.21732038f, 0.4393713f, 0.42800313f, 0.5006979f}); + auto x5 = NDArrayFactory::create('c', {1, 3}); + auto x6 = NDArrayFactory::create('c', {1, 3}); + auto x7 = NDArrayFactory::create('c', {1, 3}); + auto x8 = NDArrayFactory::create('c', {12}); + + sd::ops::lstmBlock op; + auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, 
&x7, &x8}, {2.0, 0.3}, {0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + // z->printIndexedBuffer("Z"); +} + +TEST_F(DeclarableOpsTests15, test_lstmBlock_2) { + int seqLen = 8; + int bS = 16; + int nIn = 8; + + auto x0 = NDArrayFactory::create(5); + auto x1 = NDArrayFactory::create('f', {bS, nIn, seqLen}); + auto x2 = NDArrayFactory::create('f', {bS, nIn}); // nIn == nOut + auto x3 = NDArrayFactory::create('f', {bS, nIn}); + auto x4 = NDArrayFactory::create('f', {2 * nIn, 4 * nIn}); + auto x5 = NDArrayFactory::create('f', {nIn}); + auto x6 = NDArrayFactory::create('f', {nIn}); + auto x7 = NDArrayFactory::create('f', {nIn}); + auto x8 = NDArrayFactory::create('f', {4 * nIn}); + + sd::ops::lstmBlock op; + auto result = op.evaluate({&x0, &x1, &x2, &x3, &x4, &x5, &x6, &x7, &x8}, {1.0, 0.0}, {0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + +} + +TEST_F(DeclarableOpsTests15, test_lstmBlock_3) { + + int seqLen = 3; + int bS = 2; + int nIn = 4; + + NDArray f('f', {bS, nIn, seqLen}, sd::DataType::FLOAT32); + NDArray cLast('f', {bS, nIn}, sd::DataType::FLOAT32); + + f = 2; + cLast = 3; + + for (int t = 0; t < seqLen; ++t) { + + //section 1 + //auto ft = f({0,0, 0,0, t,t+1}); + //auto temp = ft * cLast; + + + // section 2 + auto ft = f({0,0, 0,0, t,t+1}); + auto temp1 = ft.reshape('f', {bS, nIn}); + auto temp2 = temp1 * cLast; + } +} + +TEST_F(DeclarableOpsTests15, test_empty_increasing_1) { + auto x = NDArrayFactory::create('c', {1, 0, 3}); + auto z = NDArrayFactory::create(false); + + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setOutputArray(0, &z); + + sd::ops::is_strictly_increasing op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(true, z.e(0)); +} + +TEST_F(DeclarableOpsTests15, test_empty_decreasing_1) { + auto x = NDArrayFactory::create('c', {1, 0, 3}); + auto z = NDArrayFactory::create(false); + + Context ctx(1); + ctx.setInputArray(0, &x); + 
ctx.setOutputArray(0, &z); + + sd::ops::is_non_decreasing op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(true, z.e(0)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_1) { + // rank 1 + NDArray rgbs('c', { 3 }, { 10, 50, 200 }, sd::DataType::INT32); + NDArray expected('c', { 1 }, std::vector{ 55 }, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({&rgbs}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_2) { + // rank 1 + auto rgbs = NDArrayFactory::create('f', { 3 }, { 1, 120, -25 }); + auto expected = NDArrayFactory::create('f', { 1 }, { 67 }); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({ &rgbs }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_3) { + // rank 2 + NDArray rgbs('c', { 4, 3 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, sd::DataType::INT32); + NDArray expected('c', { 4, 1 }, { 41, 105, 101, 101 }, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({ &rgbs }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_4) { + + NDArray rgbs('c', { 3, 2 }, {14, 99, 207, 10, 114, 201 }, sd::DataType::INT32); 
+ + rgbs.permutei({1,0}); + NDArray expected('c', { 2, 1 }, { 138, 58 }, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({ &rgbs }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_5) { + // rank 2 + NDArray rgbs('c', { 3, 4 }, { -94, 99, 97, 90, 114, 101, 111, 96, 105, 100, 103, 102 }, sd::DataType::INT32); + NDArray expected('c', { 1, 4 }, { 50, 100, 105, 94 }, sd::DataType::INT32); + sd::ops::rgb_to_grs op; + auto result = op.evaluate({ &rgbs }, {}, {0}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_6) { + // rank 3 + auto rgbs = NDArrayFactory::create('c', { 5,4,3 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + auto expected = NDArrayFactory::create('c', { 5,4,1 }, {-47.82958221f, 34.46305847f, 21.36137581f, 
-21.91625023f,2.49686432f, -43.59792709f, 9.64180183f, 23.04854202f,40.7946167f, 44.98754883f, -25.19047546f, 20.64586449f,-4.97033119f, 30.0226841f, 30.30688286f, 15.61459541f,43.36166f, 18.22480774f, 13.74833488f, 21.59387016f}); + + sd::ops::rgb_to_grs op; + auto result = op.evaluate({ &rgbs }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_7) { + // rank 3 + auto rgbs = NDArrayFactory::create('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + auto expected = NDArrayFactory::create('c', { 5,1,4 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f, -51.545094f,2.234142f, 20.913160f, 8.783220f, 15.955761f, 55.273506f, 36.838833f, -29.751089f, 8.148357f, 13.676106f, 1.097548f, 68.766457f, 38.690712f, 27.176361f, -14.156269f, 7.157052f }); + + sd::ops::rgb_to_grs op; + auto result = op.evaluate({ &rgbs }, {}, {1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); 
+} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_8) { + // rank 3 + auto rgbs = NDArrayFactory::create('c', { 3,5,4 }, {1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f}); + try { + sd::ops::rgb_to_grs op; + auto result = op.evaluate({ &rgbs }, {}, {}); + ASSERT_EQ(Status::THROW(), result.status()); + + } catch (std::exception& e) { + nd4j_printf("Error should be here `%s'. 
It's OK.\n", e.what()); + } +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_grs_9) { + // rank 3 + auto rgbs = NDArrayFactory::create('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f}); + auto expected = NDArrayFactory::create('f', { 2,2,1 }, { 36.626545f, 38.607746f, -40.614971f, 18.233341f }); + + sd::ops::rgb_to_grs op; + auto result = op.evaluate({ &rgbs }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_1) { + // rank 1 + NDArray rgbs('f', { 3 }, { 10, 50, 200 }, sd::DataType::FLOAT32); + NDArray expected('f', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, sd::DataType::FLOAT32); + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({ &rgbs }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_2) { + + NDArray rgbs('c', { 3, 2 }, { 14., 99., 207., 10., 114., 201. 
}, sd::DataType::FLOAT32); + rgbs.permutei({ 1,0 }); + + NDArray expected('c', { 2, 3 }, { 138.691, -12.150713, -109.38929, 58.385, 70.18241, 35.63085 }, sd::DataType::FLOAT32); + sd::ops::rgb_to_yuv op; + + auto result = op.evaluate({ &rgbs }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_3) { + // rank 2 + NDArray rgbs('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, sd::DataType::FLOAT32); + NDArray expected('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({ &rgbs }, {}, { 0 }); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_4) { + // rank 3 + NDArray rgbs('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, 
-2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, sd::DataType::FLOAT32); + NDArray expected('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, -25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, - 10.28950082, - 78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, - 18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, - 26.88963173, 47.0880442, - 0.13584441, - 35.60035823, 43.2050762, - 18.47048906, - 31.11782117, 47.642019, - 18.83162118, - 21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({ &rgbs }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_5) { + // rank 3 + NDArray rgbs('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, 
-1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); + NDArray expected('c', { 5,3,4 }, { 36.628319, 38.600643,-40.624989, 18.231001, - 14.822637, - 2.479566, - 8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,- 9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, - 3.555702,- 3.225931,3.063015, - 36.134724,58.302204, 8.477802, 38.695396,27.181587, - 14.157411,7.157054, 11.714512, 22.148155, 11.580557, - 27.204905,7.120562, 21.992094, 2.406748, - 6.265247, }, sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({ &rgbs }, {}, { 1 }); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_6) { + // rank 3 + NDArray rgbs('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 
5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); + try { + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({ &rgbs }, {}, {}); + ASSERT_EQ(Status::THROW(), result.status()); + + } + catch (std::exception & e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_rgb_to_yuv_7) { + // rank 3 + NDArray rgbs('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, sd::DataType::FLOAT32); + NDArray expected('f', { 2,2,3 }, { 36.628319,38.600643, -40.624989,18.231001, -14.822637,-2.479566, -8.965780, 2.223851, -16.561626,- 96.205162,-52.255379, -36.527435 }, sd::DataType::FLOAT32); + + sd::ops::rgb_to_yuv op; + auto result = op.evaluate({ &rgbs }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_1) { + // rank 1 + NDArray yuv('c', { 3 }, { 55.14 , 71.2872001, -39.6005542 }, sd::DataType::FLOAT32); + NDArray expected('c', { 3 }, { 10, 50, 200 }, sd::DataType::FLOAT32); + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({ &yuv }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_2) { + // rank 1 + NDArray yuv('f', { 3 }, { 55.14, 71.2872001, -39.6005542 }, sd::DataType::FLOAT32); + NDArray expected('f', { 3 }, { 10, 50, 200 }, 
sd::DataType::FLOAT32); + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({ &yuv }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_3) { + // rank 2 + NDArray expected('c', { 3, 4 }, { -9.4, 9.9, 9.7, 9.0, 1.14, 1.01, 1.11, 9.6, 1.05, 10.0, 1.03, 10.22 }, sd::DataType::FLOAT32); + NDArray yuv('c', { 3, 4 }, { -2.021720, 4.692970, 3.669290, 9.491281, 1.511627, 2.611648, -1.298824, 0.358612, -6.472839, 4.568039, 5.290639, -0.430992 }, sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({ &yuv }, {}, { 0 }); + auto output = result.at(0); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_4) { + // rank 3 + NDArray expected('c', { 5,4,3 }, { 1.7750e+01, 1.4602e+01, 5.4883e+00, 9.5438e+01, 1.0038e+02, 4.0531e+01, -5.8844e+01, 2.9609e+01, -1.1414e+01, 2.1391e+01, 3.9656e+01, 2.1531e+01, -7.1062e+01, -4.5859e+00, 2.9438e+01, -6.7461e+00, 6.7938e+01, -6.1211e+00, 2.2750e+01, -6.1438e+01, 1.5404e-02, -8.5312e+01, 1.1641e+01, 6.2500e+01, -1.0019e+02, 3.9344e+01, -3.1344e+01, 3.8562e+01, 5.9961e+00, 6.2219e+01, -1.0477e+01, 1.7750e+01, 2.9938e+01, 7.5830e-01, -2.7516e+01, 7.2188e+01, -2.3406e+01, 1.1617e+01, 6.5125e+01, 6.5078e+00, 6.7812e+01, 4.6812e+01, 7.7344e+00, 6.8562e+01, 5.6719e+00, 2.3125e+01, 6.7562e+01, 9.3750e+00, 5.2094e+01, -8.6562e+01, 1.2695e+01, 3.3562e+01, 2.9734e+01, 5.2250e+01, 9.5469e+00, -7.4414e+00, -2.0125e+01, 1.8145e+00, 7.8438e+01, -4.8125e+01 }, sd::DataType::FLOAT32); + NDArray yuv('c', { 5,4,3 }, { 14.5042902, -4.43686799, 2.847406, 92.079556, 
-25.36761168, 2.94630572, -1.515069, -4.87137291, -50.29369639, 32.128515, -5.21515376, -9.41983935,-20.5835293, 24.61614501, -44.28390394, 37.1647167, -21.30142676, -38.52221293, -29.26009994, 14.40679768, 45.62757638, -11.550021, 36.44083018, -64.71012983,-10.435098, -10.28950082, -78.74044941, 22.1427147, 19.72198103, 14.40435988, 10.699559, 9.46744852, -18.5778351 , -7.6957283, 39.31166179, 7.41657542, 7.245035, 28.48336771, -26.88963173, 47.0880442, -0.13584441, -35.60035823, 43.2050762, -18.47048906, -31.11782117, 47.642019, -18.83162118, -21.50836396,-33.788558, 22.87507047, 75.34330791, 33.445396, 9.25395257, 0.10229474, -3.8078287, -8.02985955, 11.71587638, 41.0993915, -43.90830496, -34.46396749 }, sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({ &yuv }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_5) { + // rank 3 + NDArray expected('c', { 5,3,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); + NDArray yuv('c', { 5,3,4 
}, { 36.628319, 38.600643,-40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626,-96.205162,-52.255379,-36.527435,-51.546139,2.234915, 20.914114, 8.785358, 32.552223, -3.356598, 9.069552, 1.393482,36.029255, 4.824605,-9.972263,11.058715, 15.947105, 55.283543, 36.845627, -29.750486,0.887228, 6.534475, -21.794132,34.155693, -89.929497,39.562351, 27.276817,31.359871, 8.149521, 13.673355, 1.104303, 68.774300, 2.236881, 13.216944, -3.555702,-3.225931,3.063015, -36.134724,58.302204, 8.477802, 38.695396,27.181587, -14.157411,7.157054, 11.714512, 22.148155, 11.580557, -27.204905,7.120562, 21.992094, 2.406748, -6.265247, }, sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({ &yuv }, {}, { 1 }); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_6) { + // rank 3 + NDArray yuv('c', { 3,5,4 }, { 1.7750e+01f, -7.1062e+01f, -1.0019e+02f,-2.3406e+01f, 5.2094e+01f, 9.5438e+01f, -6.7461e+00f, 3.8562e+01f, 6.5078e+00f,3.3562e+01f, -5.8844e+01f, 2.2750e+01f, -1.0477e+01f, 7.7344e+00f, 9.5469e+00f,2.1391e+01f, -8.5312e+01f, 7.5830e-01f,2.3125e+01f, 1.8145e+00f, 1.4602e+01f,-4.5859e+00f, 3.9344e+01f, 1.1617e+01f,-8.6562e+01f, 1.0038e+02f, 6.7938e+01f,5.9961e+00f, 6.7812e+01f, 2.9734e+01f,2.9609e+01f, -6.1438e+01f, 1.7750e+01f,6.8562e+01f, -7.4414e+00f, 3.9656e+01f,1.1641e+01f, -2.7516e+01f, 6.7562e+01f,7.8438e+01f, 5.4883e+00f, 2.9438e+01f,-3.1344e+01f, 6.5125e+01f, 1.2695e+01f,4.0531e+01f, -6.1211e+00f, 6.2219e+01f,4.6812e+01f, 5.2250e+01f, -1.1414e+01f,1.5404e-02f, 2.9938e+01f, 5.6719e+00f,-2.0125e+01f, 2.1531e+01f, 6.2500e+01f,7.2188e+01f, 9.3750e+00f, -4.8125e+01f }, sd::DataType::FLOAT32); + try { + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({ &yuv }, {}, {}); + 
ASSERT_EQ(Status::THROW(), result.status()); + + } + catch (std::exception & e) { + nd4j_printf("Error should be here `%s'. It's OK.\n", e.what()); + } +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, test_yuv_to_rgb_7) { + // rank 3 + NDArray expected('f', { 2, 2, 3 }, { 1.7750e+01f,-7.1062e+01f, -1.0019e+02f, -2.3406e+01f,5.2094e+01f,9.5438e+01f, -6.7461e+00f,3.8562e+01f, 6.5078e+00f, 3.3562e+01f,-5.8844e+01f,2.2750e+01f }, sd::DataType::FLOAT32); + NDArray yuv('f', { 2,2,3 }, { 36.628319, 38.600643, -40.624989, 18.231001, -14.822637, -2.479566, -8.965780, 2.223851, -16.561626, -96.205162, -52.255379, -36.527435 }, sd::DataType::FLOAT32); + + sd::ops::yuv_to_rgb op; + auto result = op.evaluate({ &yuv }, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////////////// + +TEST_F(DeclarableOpsTests15, Pow_BP_Test1) { + + // same shape + NDArray x('c', { 2,2,2 }, { 4,3,2,5,7,8,-9,-12 }, sd::DataType::FLOAT32); + NDArray y('c', { 2,2,2 }, { 2,3,-2,4,-1,-4,10,8 }, sd::DataType::FLOAT32); + + + NDArray dLdz('c', { 2,2,2 }, sd::DataType::FLOAT32); + NDArray dLdxExp('c', { 2,2,2 }, { 8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08 }, sd::DataType::FLOAT32); + NDArray dLdyExp('c', { 2,2,2 }, { 22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0 }, sd::DataType::FLOAT32); + + dLdz.assign(1.0); + + sd::ops::Pow_bp op; + auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* dLdx = results.at(0); + auto* dLdy = results.at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); +} + +TEST_F(DeclarableOpsTests15, 
Pow_BP_Test2) { + + NDArray x('c', { 1,2,3 }, sd::DataType::FLOAT32); + NDArray y('c', { 3,2,1 }, sd::DataType::FLOAT32); + NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); + + NDArray dLdxExp('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24., 26.4, 28.8 }, sd::DataType::FLOAT32); + NDArray dLdyExp('c', { 3,2,1 }, { 13.30843, 33.27106, 53.2337, 73.19634, 93.15898, 113.12162 }, sd::DataType::FLOAT32); + + x.assign(4.0); + y.assign(2.0); + dLdz.linspace(0.1, 0.1); + + sd::ops::Pow_bp op; + auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* dLdx = results.at(0); + auto* dLdy = results.at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test3) { + + // y - same shape as dLdz + NDArray xY('c', { 1,2,3 }, sd::DataType::FLOAT32); + NDArray yY('c', { 3,2,3 }, sd::DataType::FLOAT32); + + NDArray dLdxExpY('c', { 1,2,3 }, { 16.8, 19.2, 21.6, 24. 
, 26.4, 28.8 }, sd::DataType::FLOAT32); + NDArray dLdyExpY('c', { 3,2,3 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843, 15.5265 , 17.74457, 19.96264, 22.18071, 24.39878, 26.61685, 28.83492, 31.05299, 33.27106, 35.48914, 37.70721, 39.92528 }, sd::DataType::FLOAT32); + NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); + + xY.assign(4.0); + yY.assign(2.0); + dLdz.linspace(0.1, 0.1); + + sd::ops::Pow_bp op; + auto resultsY = op.evaluate({ &xY, &yY, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsY.status()); + + auto* dLdxY = resultsY.at(0); + auto* dLdyY = resultsY.at(1); + + ASSERT_TRUE(dLdxExpY.isSameShape(dLdxY)); + ASSERT_TRUE(dLdxExpY.equalsTo(dLdxY)); + ASSERT_TRUE(dLdyExpY.isSameShape(dLdyY)); + ASSERT_TRUE(dLdyExpY.equalsTo(dLdyY)); +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test4) { + + // x - same shape ad dLdz + NDArray yX('c', { 1,2,3 }, sd::DataType::FLOAT32); + NDArray xX('c', { 3,2,3 }, sd::DataType::FLOAT32); + + NDArray dLdxExpX('c', { 3,2,3 }, { 3.2, 6.4, 9.6, 12.8, 16. , 19.2, 22.4, 25.6, 28.8, 32. 
, 35.2, 38.4, 41.6, 44.8, 48., 51.2, 54.4, 57.6 }, sd::DataType::FLOAT32); + NDArray dLdyExpX('c', { 1,2,3 }, { 23.28975, 26.61685, 29.94396, 33.27106, 36.59817, 39.92528 }, sd::DataType::FLOAT32); + + NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); + dLdz.linspace(0.1, 0.1); + + sd::ops::Pow_bp op; + + xX.assign(2.0); + yX.assign(4.0); + + auto resultsX = op.evaluate({ &xX, &yX, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsX.status()); + + auto* dLdxX = resultsX.at(0); + auto* dLdyX = resultsX.at(1); + + ASSERT_TRUE(dLdxExpX.isSameShape(dLdxX)); + ASSERT_TRUE(dLdxExpX.equalsTo(dLdxX)); + ASSERT_TRUE(dLdyExpX.isSameShape(dLdyX)); + ASSERT_TRUE(dLdyExpX.equalsTo(dLdyX)); +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test5) { + + // both single array + NDArray xConst('c', { 1 }, sd::DataType::FLOAT32); + NDArray yConst('c', { 1 }, sd::DataType::FLOAT32); + NDArray dLdz('c', { 1 }, sd::DataType::FLOAT32); + NDArray dLdxExp('c', { 1 }, sd::DataType::FLOAT32); + NDArray dLdyExp('c', { 1 }, sd::DataType::FLOAT32); + + xConst.assign(3.0); + yConst.assign(4.0); + dLdz.assign(1.0); + + dLdxExp.assign(4.0 * pow(3, 3)); + dLdyExp.assign(pow(3, 4) * log(3)); + + sd::ops::Pow_bp op; + auto results = op.evaluate({ &xConst, &yConst, &dLdz }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* dLdx = results.at(0); + auto* dLdy = results.at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test6) { + + // x single array + NDArray xConst('c', { 1 }, sd::DataType::FLOAT32); + NDArray y('c', { 2, 2, 2 }, sd::DataType::FLOAT32); + NDArray dLdzC('c', { 2, 2, 2 }, sd::DataType::FLOAT32); + + xConst.assign(2.0); + y.assign(4.0); + dLdzC.linspace(0.1, 0.1); + + NDArray dLdxExpXC('c', { 1 }, std::vector{ 115.2 }, sd::DataType::FLOAT32); + NDArray dLdyExpXC('c', { 2, 2, 2 }, { 1.10904, 2.21807, 3.32711, 
4.43614, 5.54518, 6.65421, 7.76325, 8.87228 }, sd::DataType::FLOAT32); + + sd::ops::Pow_bp op; + auto resultsXC = op.evaluate({ &xConst, &y, &dLdzC }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsXC.status()); + + auto* dLdxXC = resultsXC.at(0); + auto* dLdyXC = resultsXC.at(1); + + ASSERT_TRUE(dLdxExpXC.isSameShape(dLdxXC)); + ASSERT_TRUE(dLdxExpXC.equalsTo(dLdxXC)); + ASSERT_TRUE(dLdyExpXC.isSameShape(dLdyXC)); + ASSERT_TRUE(dLdyExpXC.equalsTo(dLdyXC)); + +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test7) { + + // Y - scalar + auto Y = NDArrayFactory::create(2.f); + NDArray x('c', { 2, 2, 2 }, sd::DataType::FLOAT32); + NDArray dLdzC('c', { 2, 2, 2 }, sd::DataType::FLOAT32); + + dLdzC.linspace(0.1, 0.1); + x = 4.f; + + NDArray dLdxExpYs('c', { 2, 2, 2 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8, 5.6, 6.4 }, sd::DataType::FLOAT32); + + auto dLdyExpYs = NDArrayFactory::create(79.85056f); + + sd::ops::Pow_bp op; + auto resultsYs = op.evaluate({ &x, &Y, &dLdzC }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, resultsYs.status()); + + auto* dLdxY = resultsYs.at(0); + auto* dLdyY = resultsYs.at(1); + + ASSERT_TRUE(dLdxExpYs.isSameShape(dLdxY)); + ASSERT_TRUE(dLdxExpYs.equalsTo(dLdxY)); + ASSERT_TRUE(dLdyExpYs.isSameShape(dLdyY)); + ASSERT_TRUE(dLdyExpYs.equalsTo(dLdyY)); +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test8) { + // both scalars + + auto X = NDArrayFactory::create(4.f); + auto Y = NDArrayFactory::create(2.f); + NDArray dLdz = NDArrayFactory::create(0.1f); + + NDArray dLdxExp = NDArrayFactory::create(2.f*4.f*0.1f); + + NDArray dLdyExp = NDArrayFactory::create(pow(4.f, 2.f) * log(4.f) * 0.1f); + + sd::ops::Pow_bp op; + auto results = op.evaluate({ &X, &Y, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* dLdx = results.at(0); + auto* dLdy = results.at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); + +} + +TEST_F(DeclarableOpsTests15, 
Pow_BP_Test9) { + + sd::ops::Pow_bp op; + // diff shapes + NDArray x('c', { 3,2,1 }, sd::DataType::FLOAT32); + NDArray y('c', { 1,2,3 }, sd::DataType::FLOAT32); + NDArray dLdz('c', { 3,2,3 }, sd::DataType::FLOAT32); + + NDArray dLdxExp('c', { 3,2,1 }, { 4.8, 12., 19.2, 26.4, 33.6, 40.8 }, sd::DataType::FLOAT32); + NDArray dLdyExp('c', { 1,2,3 }, { 46.57949, 53.2337 , 59.88792, 66.54213, 73.19634, 79.85056 }, sd::DataType::FLOAT32); + + x.assign(4.0); + y.assign(2.0); + dLdz.linspace(0.1, 0.1); + + auto results = op.evaluate({ &x, &y, &dLdz }, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* dLdx = results.at(0); + auto* dLdy = results.at(1); + + ASSERT_TRUE(dLdxExp.isSameShape(dLdx)); + ASSERT_TRUE(dLdxExp.equalsTo(dLdx)); + ASSERT_TRUE(dLdyExp.isSameShape(dLdy)); + ASSERT_TRUE(dLdyExp.equalsTo(dLdy)); +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test10) { + + // diff shapes broadcastable + NDArray yB('c', { 1,2,3,1 }, sd::DataType::FLOAT32); + NDArray xB('c', { 2,3,1 }, sd::DataType::FLOAT32); + + NDArray dLdyExpB('c', { 1,2,3,1 }, { 2.21807, 4.43614, 6.65421, 8.87228, 11.09035, 13.30843 }, sd::DataType::FLOAT32); + NDArray dLdxExpB('c', { 2,3,1 }, { 0.8, 1.6, 2.4, 3.2, 4., 4.8 }, sd::DataType::FLOAT32); + NDArray dLdzB('c', { 1,2,3,1 }, sd::DataType::FLOAT32); + + dLdzB.linspace(0.1, 0.1); + xB.assign(4.0); + yB.assign(2.0); + + sd::ops::Pow_bp op; + auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); + + auto* dLdxB = resultsB.at(0); + auto* dLdyB = resultsB.at(1); + + ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); + ASSERT_TRUE(dLdxExpB.equalsTo(dLdxB)); + + ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); + ASSERT_TRUE(dLdyExpB.equalsTo(dLdyB)); +} + +TEST_F(DeclarableOpsTests15, Pow_BP_Test11) { +#ifdef FFAST_MATH + if (1 > 0) + return; +#endif + + NDArray xB('c', { 3,2,1 }, { .4, 3, 5, .8, -9, -12 }, sd::DataType::FLOAT32); + NDArray yB('c', { 1,2,3 }, { 3, -2, .4, -4, 10, .8 }, 
sd::DataType::FLOAT32); + + NDArray dLdxExpB('c', { 3,2,1 }, { -5.994056, 39366.191406, 7.508829, -2.223537, -std::numeric_limits::quiet_NaN(), -std::numeric_limits::quiet_NaN() }, sd::DataType::FLOAT32); + NDArray dLdyExpB('c', { 1,2,3 }, { 20.11211, -1.119612, -std::numeric_limits::quiet_NaN(), -0.1076, 12974.389648, -std::numeric_limits::quiet_NaN() }, sd::DataType::FLOAT32); + + NDArray dLdzB('c', { 3,2,3 }, { .1,.2,.3, .1,.2,.3, .1,.4,.1, .2,.1,.1, .3,.1,.5, .1, .7, .1 }, sd::DataType::FLOAT32); + + sd::ops::Pow_bp op; + auto resultsB = op.evaluate({ &xB, &yB, &dLdzB }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsB.status()); + auto* dLdxB = resultsB.at(0); + auto* dLdyB = resultsB.at(1); + + ASSERT_TRUE(dLdxExpB.isSameShape(dLdxB)); + for (int i = 0; i < dLdxB->lengthOf(); ++i) { + if (!sd::math::nd4j_isnan(dLdxB->e(i)) && !sd::math::nd4j_isnan(dLdxExpB.e(i))) + ASSERT_NEAR(dLdxB->e(i), dLdxExpB.e(i), 0.00001); + } + + ASSERT_TRUE(dLdyExpB.isSameShape(dLdyB)); + for (int i = 0; i < dLdyB->lengthOf(); ++i) { + if (!sd::math::nd4j_isnan(dLdyB->e(i)) && !sd::math::nd4j_isnan(dLdyExpB.e(i))) + ASSERT_NEAR(dLdyB->e(i), dLdyExpB.e(i), 0.00001); + } + + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP1) { + + NDArray A('c', { 1, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6 }, sd::DataType::FLOAT32); + NDArray B('c', { 1, 2, 4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8 }, sd::DataType::FLOAT32); + NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.1 }, sd::DataType::FLOAT32); + + NDArray dLdA('c', { 1, 2, 3 }, { 3.3, 8.5, 13.36, 3.7, 9.54, 15. 
}, sd::DataType::FLOAT32); + NDArray dLdB('c', { 1, 2, 4 }, { 3.38, 4.04, 4.7, 5.13, 3.83, 4.58, 5.33, 5.82 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP2) { + + NDArray A('c', { 1, 2, 3 }, { 2,2,2, 2,2,2 }, sd::DataType::FLOAT32); + NDArray B('c', { 1, 2, 3 }, { 3,3,3,3, 3,3 }, sd::DataType::FLOAT32); + NDArray dLdC('c', { 1 }, { 1 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(B.isSameShape(*dLdAbp)); + ASSERT_TRUE(B.equalsTo(*dLdAbp)); + + ASSERT_TRUE(A.isSameShape(*dLdBbp)); + ASSERT_TRUE(A.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP3) { + + NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); + NDArray B('c', { 4, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); + NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); + + NDArray dA('c', { 3, 2, 2 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. 
}, sd::DataType::FLOAT32); + NDArray dB('c', { 4, 2, 2 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP4) { + + NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); + NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); + NDArray dLdC('c', { 3, 2 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, sd::DataType::FLOAT32); + + NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, sd::DataType::FLOAT32); + NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84 , 3.768, 215.6, 28.2 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP5) { + + NDArray A('c', { 3, 4, 1, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); + NDArray B('c', { 2, 4, 1, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); + NDArray 
dLdC('c', { 3, 1, 2, 1 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, sd::DataType::FLOAT32); + + NDArray dLdA('c', { 3, 4, 1, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 29.55, 29.94 }, sd::DataType::FLOAT32); + NDArray dLdB('c', { 2, 4, 1, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP6) { + + NDArray A('c', { 2, 2, 2 }, { 2,2, 2,2, 2,2, 2,2 }, sd::DataType::FLOAT32); + NDArray B('c', { 2, 2, 2 }, { 3,3, 3,3, 3,3, 3,3 }, sd::DataType::FLOAT32); + + auto dLdC = NDArrayFactory::create(1.f); + + sd::ops::tensormmul_bp op_bp; + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(B.isSameShape(*dLdAbp)); + ASSERT_TRUE(B.equalsTo(*dLdAbp)); + + ASSERT_TRUE(A.isSameShape(*dLdBbp)); + ASSERT_TRUE(A.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP7) { + + NDArray A('c', { 3, 4, 1 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); + NDArray B('c', { 2, 4, 1 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); + NDArray dLdC('c', { 3, 1, 2, 1 }, { 1.1, 1.2, 1.3, 1.4, 1.5, 1.6 }, sd::DataType::FLOAT32); + + NDArray dLdA('c', { 3, 4, 1 }, { 7.16, 15.74, 22.15, 21.98, 8.42, 18.58, 25.85, 25.96, 9.68, 21.42, 
29.55, 29.94 }, sd::DataType::FLOAT32); + NDArray dLdB('c', { 2, 4, 1 }, { 30.49, 3.456, 201.9, 26.1, 32.84, 3.768, 215.6, 28.2 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP8) { + + NDArray A('c', { 1, 1, 4, 3 }, { 0.4, 3, 5, 9, 23, 0.12, 8, 9, 0.1, 0, 124, 3 }, sd::DataType::FLOAT32); + NDArray B('c', { 1, 1, 4, 2 }, { 4, 13, .5, 19, 2.3, 1.2, 18, .9 }, sd::DataType::FLOAT32); + NDArray dLdC('c', { 3, 2 }, { 1.1,1.2,1.3,1.4,1.5,1.6 }, sd::DataType::FLOAT32); + + NDArray dLdA('c', { 1, 1, 4, 3 }, { 20., 23.4, 26.8, 23.35, 27.25, 31.15, 3.97, 4.67, 5.37, 20.88, 24.66, 28.44 }, sd::DataType::FLOAT32); + NDArray dLdB('c', { 1, 1, 4, 2 }, { 11.84, 12.68, 39.98, 43.192, 20.65, 22.36, 165.7, 178.4 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 3,0,1,2, 3,0,1,2 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dLdA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dLdA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dLdB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dLdB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP9) { + + NDArray A('c', { 3, 2, 2, 1 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); + NDArray B('c', { 4, 2, 2 ,1 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 
4.6 }, sd::DataType::FLOAT32); + NDArray dLdC('c', { 3, 1, 4, 1 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); + + NDArray dA('c', { 3, 2, 2, 1 }, { 3.9, 4., 4.1, 4.2, 9.82, 10.08, 10.34, 10.6, 15.74, 16.16, 16.58, 17. }, sd::DataType::FLOAT32); + NDArray dB('c', { 4, 2, 2, 1 }, { 4.07, 4.22, 4.37, 4.52, 4.82, 5., 5.18, 5.36, 5.57, 5.78, 5.99, 6.2, 6.32, 6.56, 6.8, 7.04 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP10) { + + NDArray A('c', { 1, 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); + NDArray B('c', { 1, 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); + NDArray dLdC('c', { 1, 3, 1, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); + + + NDArray dA('c', { 1, 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, sd::DataType::FLOAT32); + NDArray dB('c', { 1, 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(*dLdBbp)); + 
ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP11) { + + NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); + NDArray B('c', { 2, 2 ,4 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::FLOAT32); + NDArray dLdC('c', { 3, 4 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2 }, sd::DataType::FLOAT32); + + + NDArray dA('c', { 2, 2, 3 }, { 3.3, 8.5, 13.7, 3.7, 9.54, 15.38, 4.1, 10.58, 17.06, 4.5, 11.62, 18.74 }, sd::DataType::FLOAT32); + NDArray dB('c', { 2, 2, 4 }, { 3.38, 4.04, 4.7, 5.36, 3.83, 4.58, 5.33, 6.08, 4.28, 5.12, 5.96, 6.8, 4.73, 5.66, 6.59, 7.52 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 2,0,1, 2,0,1 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP12) { + + NDArray A('c', { 2, 2, 3 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::FLOAT32); + NDArray B('c', { 2, 2 ,3 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, sd::DataType::FLOAT32); + NDArray dLdC('c', { 2, 3, 2, 3 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, + 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, sd::DataType::FLOAT32); + + NDArray dA('c', { 2, 2, 3 }, { 7.66, 20.26, 32.86, 8.29, 21.97, 35.65, 45.46, 58.06, 70.66, 49.33, 63.01, 76.69 }, sd::DataType::FLOAT32); + NDArray dB('c', { 2, 2, 3 }, { 25.86, 
27.36, 28.86, 28.74, 30.42, 32.1, 30.36, 31.86, 33.36, 33.78, 35.46, 37.14 }, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP13) { + + NDArray A('c', { 3, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2 }, sd::DataType::DOUBLE); + NDArray B('c', { 3, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2 }, sd::DataType::DOUBLE); + NDArray dLdC('c', { 3, 2, 3, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2., 2.1, 2.2, 2.3, 2.4, + 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 3.4, 3.5, 3.6 }, sd::DataType::DOUBLE); + + NDArray dA('c', { 3, 2, 2 }, { 7.79, 20.57, 8.21, 21.71, 33.35, 46.13, 35.21, 48.71, 58.91, 71.69, 62.21, 75.71 }, sd::DataType::DOUBLE); + NDArray dB('c', { 3, 2, 2 }, { 26.49, 28.02, 28.41, 30.06, 29.55, 31.08, 31.71, 33.36, 32.61, 34.14, 35.01, 36.66 }, sd::DataType::DOUBLE); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP14) { + + NDArray A('c', { 2, 2, 2, 2 }, { 2.1, 2.2, 2.3, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9, 3., 3.1, 3.2, 3.3, 
3.4, 3.5, 3.6 }, sd::DataType::DOUBLE); + + NDArray B('c', { 2, 2, 2, 2 }, { 3.1, 3.2, 3.3, 3.4, 3.5, 3.6, 3.7, 3.8, 3.9, 4., 4.1, 4.2, 4.3, 4.4, 4.5, 4.6 }, sd::DataType::DOUBLE); + + NDArray dLdC('c', { 2, 2, 2, 2, 2, 2 }, { .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6, .1, .2, .3, .4, .5, .6, .7, .8, .9, 1, 1.1, 1.2, + 1.3, 1.4, 1.5, 1.6 }, sd::DataType::DOUBLE); + + NDArray dA('c', { 2, 2, 2, 2 }, { 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24, 13.88, 37.24, 13.88, 37.24, 15.32, 41.24, 15.32, 41.24 }, sd::DataType::DOUBLE); + NDArray dB('c', { 2, 2, 2, 2 }, { 10.76, 12.88, 15., 17.12, 12.36, 14.8, 17.24, 19.68, 19.24, 21.36, 23.48, 25.6, 22.12, 24.56, 27., 29.44 }, sd::DataType::DOUBLE); + + sd::ops::tensormmul_bp op_bp; + + auto resultsBP = op_bp.evaluate({ &A, &B, &dLdC }, {}, { 1,1, 1,1 }, {}); + + ASSERT_EQ(ND4J_STATUS_OK, resultsBP.status()); + + auto* dLdAbp = resultsBP.at(0); + auto* dLdBbp = resultsBP.at(1); + + ASSERT_TRUE(dA.isSameShape(*dLdAbp)); + ASSERT_TRUE(dA.equalsTo(*dLdAbp)); + + ASSERT_TRUE(dB.isSameShape(*dLdBbp)); + ASSERT_TRUE(dB.equalsTo(*dLdBbp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP15) { + + NDArray A('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::FLOAT32); + NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::FLOAT32); + + NDArray dLdC('f', { 2, 2 }, { 23.0, 24.44, 2.0, 26. }, sd::DataType::FLOAT32); + + NDArray dA('c', { 2, 2, 3 }, { 27., 127., 227., 77., 177., 277., 76.44, 278.20001, 479.96002, 177.32, 379.08001, 580.839966 }, sd::DataType::FLOAT32); + NDArray dB('f', { 2, 2, 3 }, { 194.08, 184., 336.4, 268., 241.52, 212., 383.839996, 296., 288.96002, 240., 431.27999, 324. 
}, sd::DataType::FLOAT32); + + sd::ops::tensormmul_bp op; + auto results = op.evaluate({ &A, &B, &dLdC }, {}, { 2,1,2,2,1,2 }); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* dLdA = results.at(0); + auto* dLdB = results.at(1); + + ASSERT_TRUE(dA.isSameShape(*dLdA)); + ASSERT_TRUE(dA.equalsTo(*dLdA)); + + ASSERT_TRUE(dB.isSameShape(*dLdB)); + ASSERT_TRUE(dB.equalsTo(*dLdB)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP16) { + + NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); + NDArray B('c', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); + + NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE); + + const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 }); + const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }); + + sd::ops::tensormmul op; + sd::ops::tensormmul_bp op_bp; + + const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, {1,0}); + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, TestTensorMmul_BP17) { + + NDArray A('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }, sd::DataType::DOUBLE); + NDArray B('f', { 2, 2, 3 }, { 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. 
}, sd::DataType::DOUBLE); + + NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE); + + const OpArgsHolder argsHolderFF({ &A, &B }, {}, { 2,1,2, 2,1,2 }); + const OpArgsHolder argsHolderBP({ &A, &B, &dLdC }, {}, { 2,1,2, 2,1,2 }); + + sd::ops::tensormmul op; + sd::ops::tensormmul_bp op_bp; + + const bool isGradCorrect = GradCheck::checkGrad(op, op_bp, argsHolderFF, argsHolderBP, { 1,0 }); + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, gru_1) { + + const int sL = 3; + const int bS = 2; + const int nIn = 5; + const int nOut = 4; + + + NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. , 14.5, 15.}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::FLOAT32); + NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::FLOAT32); + NDArray b('c', {3*nOut}, sd::DataType::FLOAT32); + + NDArray expH('c', {sL, bS, nOut}, {-1.681847, -1.062565, -0.443283, 0.175998,0.837823, 1.488041, 2.13826 , 2.788478, -0.888747, -0.491826, -0.094907, 0.302014, + 0.751355, 1.182715, 1.614075, 2.045434, -0.388876, -0.126716, 0.135444, 0.397604,0.710558, 1.002922, 1.295287, 1.587651}, sd::DataType::FLOAT32); + + Wx = 0.003; + Wh = 0.006; + b = 0.5; + + NDArray dLdC('c', { 2, 2 }, sd::DataType::DOUBLE); + + sd::ops::gru op; + auto results = op.evaluate({&x, &hI, &Wx, &Wh, &b}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto* h = results.at(0); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, sqrtm_1) { + + NDArray x1('c', {1,1}, {4.}, sd::DataType::DOUBLE); + NDArray x2('c', {2,2}, {1.3,2,0.3,.5}, sd::DataType::DOUBLE); + NDArray x3('c', {3,3}, {0.5 ,-0.4 
,1.2 ,-2.8 ,-0.2 ,-2.1 ,-2.4 ,-2.0 ,1.1}, sd::DataType::DOUBLE); + NDArray x4('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); + NDArray x5('c', {5,5}, {2.4 ,0.3 ,0.0 ,1.1 ,1.8 ,0.1 ,1.7 ,2.7 ,1.5 ,2.6 ,0.6 ,2.1 ,2.2 ,1.0 ,0.2 ,1.2 ,2.8 ,1.9 ,0.8 ,2.0 ,0.5 ,1.6 ,0.9 ,1.4 ,2.5}, sd::DataType::DOUBLE); + + NDArray exp1('c', {1,1}, {2.}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {1.0163674, 1.3341597,0.200124, 0.4827035}, sd::DataType::DOUBLE); + NDArray exp3('c', {3,3}, {6.5692188, 2.6273616,-0.1387864,-16.8404762,-7.0296495, 0.9204148,-11.4664296,-5.834273 , 2.2087478}, sd::DataType::DOUBLE); + NDArray exp4('c', {4,4}, {1.161387 ,-1.9343154, 0.230372 , 0.8660897,0.80588 , 3.4045446,-1.0152824,-2.0369467,2.2589629, 1.9674252, 1.5109997,-1.4283141,0.0226356, 1.3032279,-1.00396 , 1.8278487}, sd::DataType::DOUBLE); + NDArray exp5('c', {5,5}, {1.4175046,-0.4425298, 0.1846149, 0.3166522, 0.9140631,-0.1929139, 0.2889113, 1.4045273, 0.2600026, 1.552021 , 0.1372758, 0.5703854, 1.3336126, 0.3869317,-0.082492 , + 0.8607272, 3.1792474,-0.9499947, 0.8541668,-1.4243879, 0.0081136,-0.0622248, 0.4534325, 0.4641865, 1.8132138}, sd::DataType::DOUBLE); + + sd::ops::sqrtm op; + + auto results = op.evaluate({&x1}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp1.isSameShape(results.at(0))); + ASSERT_TRUE(exp1.equalsTo(results.at(0))); + + results = op.evaluate({&x2}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp2.isSameShape(results.at(0))); + ASSERT_TRUE(exp2.equalsTo(results.at(0))); + + results = op.evaluate({&x3}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp3.isSameShape(results.at(0))); + ASSERT_TRUE(exp3.equalsTo(results.at(0))); + + results = op.evaluate({&x4}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp4.isSameShape(results.at(0))); + 
ASSERT_TRUE(exp4.equalsTo(results.at(0))); + + results = op.evaluate({&x5}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(exp5.isSameShape(results.at(0))); + ASSERT_TRUE(exp5.equalsTo(results.at(0))); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, sqrtm_2) { + + NDArray x('c', {10,10}, {-0.3 ,2.7 ,4.9 ,7.0 ,7.3 ,-1.3 ,0.5 ,9.9 ,-9.4 ,8.4 ,2.2 ,5.2 ,7.6 ,1.2 ,2.0 ,-3.8 ,2.1 ,6.1 ,1.6 ,6.9 ,5.1 ,5.3 ,6.4 ,8.7 ,0.1 ,8.5 , + 3.3 ,1.0 ,6.8 ,0.4 ,0.7 ,3.2 ,7.4 ,6.7 ,1.1 ,7.2 ,6.0 ,7.5 ,9.7 ,5.4 ,9.0 ,6.3 ,0.0 ,4.5 ,8.3 ,7.9 ,3.0 ,6.5 ,0.6 ,8.0 ,9.5 ,3.6 ,1.9 ,6.2 ,0.9 ,4.0 ,4.1 , + 8.1 ,3.9 ,4.3 ,4.7 ,3.7 ,3.4 ,5.8 ,10.0 ,8.6 ,9.3 ,9.1 ,4.6 ,1.4 ,7.8 ,1.5 ,7.7 ,4.2 ,9.6 ,8.2 ,-7.1 ,5.7 ,5.5 ,2.6 ,8.8 ,2.9 ,0.2 ,5.6 ,-2.5 ,8.9 ,2.8 ,0.8 ,1.5 ,3.1 ,3.5 ,4.4 ,2.4 ,9.2 ,-4.8 ,1.7 ,6.6 ,9.8 ,1.8 ,5.9}, sd::DataType::DOUBLE); + + NDArray expZ('c', {10,10}, {1.2779038, 0.0333321, 0.8215617, 0.5736392, 1.3973911, -1.1757741,0.1990005, 1.5893778, -3.0159568, 2.5829108,0.5692253, 2.219431 , 1.022612 , -0.3131795, -0.1957848, -1.7805065, + 0.6668489, 1.1968921, 0.9781974, 1.2007764,0.7028634, 0.7496937, 2.2511438, 2.1945378, 0.2559353, 2.8948612,-0.4306994, -0.9922216, 0.3884369, -1.4174481, + -1.6060233, 0.1571057, 1.432471 , 0.4508346, 0.0618069, -2.4511742,2.0641709, 2.4751085, 1.84787 , 3.4146313,0.7774219, 0.768369 , -0.1417226, -0.3970577, 2.9512879, 0.5474537, + 0.4991412, 0.7604095, 0.4523091, 1.7813704,2.5998339, 0.9402402, -0.82775 , 2.3637147, -0.6394584, 4.6181937,-0.1762181, -0.2820475, 0.9280713, -2.1876918, + 0.1576249, 0.336376 , 0.2017592, 0.851786 , 1.3542577, 1.2752901,2.9718476, 1.1102557, 0.0067319, -0.2652283,0.8839235, -0.2637131, 1.5687876, 0.5156139, 1.9015886, 0.9087172, + -1.5607482, 2.4216275, 1.0399745, -0.4930439,1.3044354, 0.1690006, 0.2106909, -0.2683631, -0.4193939, 1.0233265,0.4571777, -0.2024148, 2.3564855, 1.0442339, + 1.1073322, 1.0728525, -0.5917566, 2.2267418, 
-1.6096582, 2.0685315,0.6800798, 0.4451858, -0.4048465, 1.2347676}, sd::DataType::DOUBLE); + sd::ops::sqrtm op; + + auto results = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(expZ.isSameShape(results.at(0))); + ASSERT_TRUE(expZ.equalsTo(results.at(0))); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests16.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests16.cpp new file mode 100644 index 000000000..dbdc87079 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests16.cpp @@ -0,0 +1,1499 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + + + // + // @author raver119@gmail.com + // + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace sd; + + +class DeclarableOpsTests16 : public testing::Test { +public: + + DeclarableOpsTests16() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests16, scatter_upd_1) { + auto x = NDArrayFactory::create('c', { 3 }, { 1.f, 1.f, 1.f }); + auto y = NDArrayFactory::create(0); + auto w = NDArrayFactory::create(3.0f); + auto e = NDArrayFactory::create('c', { 3 }, { 3.f, 1.f, 1.f }); + + sd::ops::scatter_upd op; + auto result = op.evaluate({ &x, &y, &w }); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); +} + +TEST_F(DeclarableOpsTests16, scatter_upd_2) { + + NDArray x('c', { 10, 3 }, sd::DataType::FLOAT32); + NDArray indices('c', { 2 }, { 2,5 }, sd::DataType::INT32); + NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, sd::DataType::FLOAT32); + NDArray e('c', { 10, 3 }, { 1,2,3, 4,5,6, 100,101,102, 10,11,12, 13,14,15, 200,201,202, 19,20,21, 22,23,24, 25,26,27, 28,29,30 }, sd::DataType::FLOAT32); + + x.linspace(1); + + sd::ops::scatter_upd op; + auto result = op.evaluate({ &x, &indices, &updates }); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); +} + +TEST_F(DeclarableOpsTests16, scatter_upd_3) { + + NDArray x('c', { 10, 3 }, sd::DataType::FLOAT32); + NDArray indices('c', { 2 }, { 20,5 }, sd::DataType::INT32); + NDArray updates('c', { 2, 3 }, { 100,101,102, 200,201,202 }, sd::DataType::FLOAT32); + NDArray output('c', { 10, 3 }, sd::DataType::FLOAT32); + + sd::ops::scatter_upd op; + ASSERT_ANY_THROW(op.execute({ &x, &indices, &updates }, { &output }, {}, {}, { true, true })); +} + +TEST_F(DeclarableOpsTests16, test_size_dtype_1) { + auto x = NDArrayFactory::create('c', { 3 }, { 1, 
1, 1 }); + auto z = NDArrayFactory::create(0.0f); + auto e = NDArrayFactory::create(3.0f); + + sd::ops::size op; + auto status = op.execute({ &x }, { &z }, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests16, test_empty_noop_1) { + auto z = NDArrayFactory::empty(); + + sd::ops::noop op; + auto status = op.execute({}, { &z }, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_empty_noop_2) { + auto z = NDArrayFactory::empty(); + + Context ctx(1); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + sd::ops::noop op; + auto status = op.execute(&ctx); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_svd_1) { + auto x = NDArrayFactory::create('c', { 3, 3 }, { 0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f,0.50563407f, 0.89252293f, 0.5461209f }); + auto z = NDArrayFactory::create('c', { 3 }); + + sd::ops::svd op; + auto status = op.execute({ &x }, { &z }, {}, { 0, 0, 16 }, {}); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_hamming_distance_1) { + auto x = NDArrayFactory::create({ 37, 37, 37 }); + auto y = NDArrayFactory::create({ 8723, 8723, 8723 }); + auto e = NDArrayFactory::create(18); + + sd::ops::bits_hamming_distance op; + auto result = op.evaluate({ &x, &y }); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); +} + +TEST_F(DeclarableOpsTests16, test_knn_mindistance_1) { + auto input = NDArrayFactory::create('c', { 512 }); + auto low = NDArrayFactory::create('c', { 512 }); + auto high = NDArrayFactory::create('c', { 512 }); + + auto output = NDArrayFactory::create(0.0f); + + input.linspace(1.0); + low.linspace(1.0); + high.linspace(1.0); + + sd::ops::knn_mindistance op; + auto result = op.execute({ &input, &low, &high }, { &output }, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); +} + 
+TEST_F(DeclarableOpsTests16, test_empty_cast_1) { + auto x = NDArrayFactory::create('c', { 1, 0, 2 }); + auto e = NDArrayFactory::create('c', { 1, 0, 2 }); + + sd::ops::cast op; + auto result = op.evaluate({&x}, {10}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(e, *result.at(0)); +} + +TEST_F(DeclarableOpsTests16, test_range_1) { + sd::ops::range op; + auto z = NDArrayFactory::create('c', { 200 }); + + Context ctx(1); + ctx.setTArguments({ -1.0, 1.0, 0.01 }); + ctx.setOutputArray(0, &z); + + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(DeclarableOpsTests16, test_range_2) { + sd::ops::range op; + auto z = NDArrayFactory::create('c', { 200 }); + + double tArgs[] = { -1.0, 1.0, 0.01 }; + + auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0); + // shape::printShapeInfoLinear("Result", shapes->at(0)); + ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0))); + + delete shapes; +} + +TEST_F(DeclarableOpsTests16, test_reverse_1) { + std::vector rows = { 3, 5, 7, 8, 9, 10, 119, 211 }; + std::vector columns = { 6, 5, 10, 100, 153, 171, 635 }; + + for (auto r : rows) { + for (auto c : columns) { + //nd4j_printf("Trying [%i, %i]\n", r, c); + auto array = NDArrayFactory::create('c', { r, c }); + auto exp = NDArrayFactory::create('c', { r, c }); + auto reversed = NDArrayFactory::create('c', { r, c }); + + auto rowOriginal = NDArrayFactory::create('c', { c }); + auto rowReversed = NDArrayFactory::create('c', { c }); + + for (int e = 0; e < c; e++) { + rowOriginal.p(e, (float)e); + rowReversed.p(c - e - 1, (float)e); + } + + + auto listI = array.allTensorsAlongDimension({ 1 }); + auto listE = exp.allTensorsAlongDimension({ 1 }); + + for (int e = 0; e < r; e++) { + listI.at(e)->assign(rowOriginal); + listE.at(e)->assign(rowReversed); + } + + sd::ops::reverse op; + Nd4jLong axis = 1; + auto status = op.execute({ &array }, { &reversed }, {}, { axis 
}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp, reversed); + } + } +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_1) { + /* + test case generated by python colorsys and scaled to suit our needs + from colorsys import * + from random import * + import numpy as np + rgbs = np.random.uniform(0,1, 5*4*3 ).astype('float32').reshape([5,4,3]) + hsvs=np.apply_along_axis(lambda x: np.array(rgb_to_hsv(x[0],x[1],x[2])),2,rgbs) + rgbs.ravel() + hsvs.ravel() + */ + auto rgbs = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, + 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, + 0.54742825f, 0.684074104f, 0.52110225f, 0.761800349f, 0.486593395f, + 0.753103435f, 0.237176552f, 0.263826847f, 0.913557053f, 0.90049392f, + 0.290193319f, 0.46850124f, 0.965541422f, 0.148351923f, 0.674094439f, + 0.524110138f, 0.216262609f, 0.0361763388f, 0.2204483f, 0.279114306f, + 0.3721793f, 0.632020354f, 0.25007084f, 0.823592246f, 0.637001634f, + 0.30433768f, 0.0448598303f, 0.385092884f, 0.366362303f, 0.586083114f, + 0.218390301f, 0.931746006f, 0.978048146f, 0.762684941f, 0.00208298792f, + 0.91390729f, 0.505838513f, 0.875348926f, 0.428009957f, 0.367065936f, + 0.911922634f, 0.270003974f, 0.164243385f, 0.0581932105f, 0.313204288f, + 0.644775152f, 0.437950462f, 0.775881767f, 0.575452209f, 0.946475744f + }); + auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, + 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, + 0.199753001f, 0.684074104f, 0.312434604f, 0.361258626f, 0.761800349f, + 0.991390795f, 0.685067773f, 0.753103435f, 0.163174023f, 0.682347894f, + 0.913557053f, 0.268038541f, 0.84635365f, 0.965541422f, 0.112067183f, + 0.679180562f, 0.674094439f, 0.540247589f, 0.870388806f, 0.279114306f, + 0.280050347f, 0.604331017f, 0.632020354f, 0.106776128f, 0.630475283f, + 0.823592246f, 0.490824632f, 
0.883509099f, 0.385092884f, 0.75257351f, + 0.765611768f, 0.931746006f, 0.129888852f, 0.997870266f, 0.978048146f, + 0.849081645f, 0.446510047f, 0.91390729f, 0.685308874f, 0.597481251f, + 0.911922634f, 0.0834472676f, 0.784472764f, 0.270003974f, 0.396037966f, + 0.514242649f, 0.644775152f, 0.756701186f, 0.392005324f, 0.946475744f + }); + + + auto actual = NDArrayFactory::create('c', { 5,4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); +#if 0 + //visual check + rgbs.printBuffer("rgbs "); + actual.printBuffer("HSV "); + expected.printBuffer("exp"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_2) { + /* + swapped_rgbs=rgbs.swapaxes(1,2).ravel() + swapped_hsvs=hsvs.swapaxes(1,2).ravel() + */ + auto rgbs = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.52110225f, 0.753103435f, 0.913557053f, + 0.46850124f, 0.761800349f, 0.237176552f, 0.90049392f, 0.965541422f, + 0.486593395f, 0.263826847f, 0.290193319f, 0.148351923f, 0.674094439f, + 0.0361763388f, 0.3721793f, 0.823592246f, 0.524110138f, 0.2204483f, + 0.632020354f, 0.637001634f, 0.216262609f, 0.279114306f, 0.25007084f, + 0.30433768f, 0.0448598303f, 0.586083114f, 0.978048146f, 0.91390729f, + 0.385092884f, 0.218390301f, 0.762684941f, 0.505838513f, 0.366362303f, + 0.931746006f, 0.00208298792f, 0.875348926f, 0.428009957f, 0.270003974f, + 0.313204288f, 0.775881767f, 0.367065936f, 0.164243385f, 0.644775152f, + 0.575452209f, 0.911922634f, 0.0581932105f, 0.437950462f, 0.946475744f + }); + auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 
0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f, 0.312434604f, 0.991390795f, 0.163174023f, + 0.268038541f, 0.361258626f, 0.685067773f, 0.682347894f, 0.84635365f, + 0.761800349f, 0.753103435f, 0.913557053f, 0.965541422f, 0.112067183f, + 0.540247589f, 0.280050347f, 0.106776128f, 0.679180562f, 0.870388806f, + 0.604331017f, 0.630475283f, 0.674094439f, 0.279114306f, 0.632020354f, + 0.823592246f, 0.490824632f, 0.75257351f, 0.129888852f, 0.849081645f, + 0.883509099f, 0.765611768f, 0.997870266f, 0.446510047f, 0.385092884f, + 0.931746006f, 0.978048146f, 0.91390729f, 0.685308874f, 0.0834472676f, + 0.396037966f, 0.756701186f, 0.597481251f, 0.784472764f, 0.514242649f, + 0.392005324f, 0.911922634f, 0.270003974f, 0.644775152f, 0.946475744f + }); + + + auto actual = NDArrayFactory::create('c', { 5,3,4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_3) { + + auto rgbs = NDArrayFactory::create('c', { 4, 3 }, { + 0.545678377f, 0.725874603f, 0.413571358f, 0.644941628f, 0.517642438f, + 0.890151322f, 0.461456001f, 0.0869259685f, 0.928968489f, 0.588904262f, + 0.54742825f, 0.684074104f + }); + auto expected = NDArrayFactory::create('c', { 4, 3 }, { + 0.262831867f, 0.430244058f, 0.725874603f, 0.723622441f, 0.418478161f, + 0.890151322f, 0.740797927f, 0.906427443f, 0.928968489f, 0.717254877f, + 0.199753001f, 0.684074104f + }); + + auto actual = NDArrayFactory::create('c', { 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_4) { + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 
0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + auto expected = NDArrayFactory::create('c', { 3, 4 }, { + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + + auto actual = NDArrayFactory::create('c', { 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_5) { + auto rgbs = NDArrayFactory::create('c', { 3 }, { + 0.545678377f, 0.725874603f, 0.413571358f + }); + auto expected = NDArrayFactory::create('c', { 3 }, { + 0.262831867f, 0.430244058f, 0.725874603f + }); + + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) { + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 0.545678377f, 0.644941628f, 0.461456001f, 0.588904262f, 0.725874603f, + 0.517642438f, 0.0869259685f, 0.54742825f, 0.413571358f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { + 0.262831867f, 0.723622441f, 0.740797927f, 0.717254877f, 0.430244058f, + 0.418478161f, 0.906427443f, 0.199753001f, 0.725874603f, 0.890151322f, + 0.928968489f, 0.684074104f + }); + + //get subarray + //get subarray + NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); + NDArray expected = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); + 
subArrRgbs.reshapei({ 3 }); + expected.reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrRgbs.printShapeInfo("subArrRgbs"); +#endif + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &subArrRgbs); + ctx.setOutputArray(0, &actual); + sd::ops::rgb_to_hsv op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_1) { + + auto hsvs = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f, 0.239250854f, 0.499201417f, 0.862712979f, + 0.0853395388f, 0.0810681432f, 0.226065159f, 0.851340771f, 0.602043271f, + 0.690895379f, 0.971996486f, 0.273846686f, 0.464318275f, 0.194078103f, + 0.219649255f, 0.616706491f, 0.847525477f, 0.653597355f, 0.700065672f, + 0.0299375951f, 0.184475258f, 0.274936169f, 0.196718201f, 0.179381892f, + 0.934476376f, 0.895766437f, 0.52967906f, 0.675635338f, 0.966644645f, + 0.770889699f, 0.556649387f, 0.13426739f, 0.899450243f, 0.817096591f, + 0.150202557f, 0.763557851f, 0.709604502f, 0.741747797f, 0.657703638f, + 0.167678103f, 0.828556478f, 0.615502477f, 0.478080243f, 0.447288662f, + 0.864299297f, 0.129833668f, 0.66402483f, 0.795475543f, 0.561332941f + }); + auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f, 0.675155059f, 0.862712979f, 0.432045438f, + 0.226065159f, 0.21712242f, 0.207738476f, 0.690895379f, 0.274946465f, + 0.645954334f, 0.464318275f, 0.337166255f, 0.358530475f, 0.594427716f, + 0.616706491f, 0.481247369f, 0.700065672f, 0.242504601f, 0.661103036f, + 0.274936169f, 0.233327664f, 0.224217249f, 0.904251479f, 0.934476376f, + 
0.766848235f, 0.675635338f, 0.317765447f, 0.54157777f, 0.556649387f, + 0.127534108f, 0.213413864f, 0.817096591f, 0.674227886f, 0.0821588641f, + 0.709604502f, 0.656080596f, 0.167780413f, 0.107076412f, 0.0573956046f, + 0.167678103f, 0.46964643f, 0.183820669f, 0.478080243f, 0.01761852f, + 0.129833668f, 0.0943436049f, 0.114806315f, 0.121884218f, 0.561332941f + }); + + + auto actual = NDArrayFactory::create('c', { 5,4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_2) { + auto hsvs = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f, 0.239250854f, 0.0853395388f, 0.851340771f, + 0.971996486f, 0.499201417f, 0.0810681432f, 0.602043271f, 0.273846686f, + 0.862712979f, 0.226065159f, 0.690895379f, 0.464318275f, 0.194078103f, + 0.847525477f, 0.0299375951f, 0.196718201f, 0.219649255f, 0.653597355f, + 0.184475258f, 0.179381892f, 0.616706491f, 0.700065672f, 0.274936169f, + 0.934476376f, 0.895766437f, 0.966644645f, 0.13426739f, 0.150202557f, + 0.52967906f, 0.770889699f, 0.899450243f, 0.763557851f, 0.675635338f, + 0.556649387f, 0.817096591f, 0.709604502f, 0.741747797f, 0.828556478f, + 0.447288662f, 0.66402483f, 0.657703638f, 0.615502477f, 0.864299297f, + 0.795475543f, 0.167678103f, 0.478080243f, 0.129833668f, 0.561332941f + }); + auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f, 0.675155059f, 0.226065159f, 0.690895379f, + 0.464318275f, 0.862712979f, 0.21712242f, 0.274946465f, 0.337166255f, + 0.432045438f, 
0.207738476f, 0.645954334f, 0.358530475f, 0.594427716f, + 0.700065672f, 0.274936169f, 0.904251479f, 0.616706491f, 0.242504601f, + 0.233327664f, 0.934476376f, 0.481247369f, 0.661103036f, 0.224217249f, + 0.766848235f, 0.675635338f, 0.556649387f, 0.817096591f, 0.709604502f, + 0.317765447f, 0.127534108f, 0.674227886f, 0.656080596f, 0.54157777f, + 0.213413864f, 0.0821588641f, 0.167780413f, 0.107076412f, 0.46964643f, + 0.01761852f, 0.114806315f, 0.0573956046f, 0.183820669f, 0.129833668f, + 0.121884218f, 0.167678103f, 0.478080243f, 0.0943436049f, 0.561332941f + }); + auto actual = NDArrayFactory::create('c', { 5,3,4 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_3) { + auto hsvs = NDArrayFactory::create('c', { 4, 3 }, { + 0.705504596f, 0.793608069f, 0.65870738f, 0.848827183f, 0.920532584f, + 0.887555957f, 0.72317636f, 0.563831031f, 0.773604929f, 0.269532293f, + 0.332347751f, 0.111181192f + }); + auto expected = NDArrayFactory::create('c', { 4, 3 }, { + 0.257768334f, 0.135951888f, 0.65870738f, 0.887555957f, 0.0705317783f, + 0.811602857f, 0.485313689f, 0.337422464f, 0.773604929f, 0.0883753772f, + 0.111181192f, 0.074230373f + }); + auto actual = NDArrayFactory::create('c', { 4,3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_4) { + auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { + 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f + }); + auto expected = 
NDArrayFactory::create('c', { 3, 4 }, { + 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f + }); + auto actual = NDArrayFactory::create('c', { 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_5) { + + auto hsvs = NDArrayFactory::create('c', { 3 }, { + 0.705504596f, 0.793608069f, 0.65870738f + }); + auto expected = NDArrayFactory::create('c', { 3 }, { + 0.257768334f, 0.135951888f, 0.65870738f + }); + + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &hsvs); + ctx.setOutputArray(0, &actual); + + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + +TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) { + + auto hsvs = NDArrayFactory::create('c', { 3, 4 }, { + 0.705504596f, 0.848827183f, 0.72317636f, 0.269532293f, 0.793608069f, + 0.920532584f, 0.563831031f, 0.332347751f, 0.65870738f, 0.887555957f, + 0.773604929f, 0.111181192f + }); + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 0.257768334f, 0.887555957f, 0.485313689f, 0.0883753772f, 0.135951888f, + 0.0705317783f, 0.337422464f, 0.111181192f, 0.65870738f, 0.811602857f, + 0.773604929f, 0.074230373f + }); + + auto actual = NDArrayFactory::create('c', { 3 }); + //get subarray + NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) }); + subArrHsvs.reshapei({ 3 }); + NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); + expected.reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrHsvs.printShapeInfo("subArrHsvs"); +#endif + + Context 
ctx(1); + ctx.setInputArray(0, &subArrHsvs); + ctx.setOutputArray(0, &actual); + sd::ops::hsv_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + + + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_1) { + /** + generated using numpy + _rgb_to_yiq_kernel = np.array([[0.299f, 0.59590059f, 0.2115f], + [0.587f, -0.27455667f, -0.52273617f], + [0.114f, -0.32134392f, 0.31119955f]]) + nnrgbs = np.array([random() for x in range(0,3*4*5)],np.float32).reshape([5,4,3]) + out =np.tensordot(nnrgbs,_rgb_to_yiq_kernel,axes=[[len(nnrgbs.shape)-1],[0]]) + + #alternatively you could use just with apply + out_2=np.apply_along_axis(lambda x: _rgb_to_yiq_kernel.T @ x,len(nnrgbs.shape)-1,nnrgbs) + + */ + auto rgb = NDArrayFactory::create('c', { 5, 4 ,3 }, + { + 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, + 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , + 0.98633456f, 0.00158441f, 0.97605824f, 0.02462568f, 0.14837205f, + 0.00112842f, 0.99260217f, 0.9585542f , 0.41196227f, 0.3095014f , + 0.6620493f , 0.30888894f, 0.3122602f , 0.7993488f , 0.86656475f, + 0.5997049f , 0.9776477f , 0.72481847f, 0.7835693f , 0.14649455f, + 0.3573504f , 0.33301765f, 0.7853056f , 0.25830218f, 0.59289205f, + 0.41357264f, 0.5934154f , 0.72647524f, 0.6623308f , 0.96197623f, + 0.0720306f , 0.23853847f, 0.1427159f , 0.19581454f, 0.06766324f, + 0.10614152f, 0.26093867f, 0.9584985f , 0.01258832f, 0.8160156f , + 0.56506383f, 0.08418505f, 0.86440504f, 0.6807802f , 0.20662387f, + 0.4153733f , 0.76146203f, 0.50057423f, 0.08274968f, 0.9521758f + }); + + auto expected = NDArrayFactory::create('c', { 5, 4 ,3 }, + { + 0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, + 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, + -0.07432612f, -0.44518381f, 0.32321111f, 0.52719408f, 0.2397369f , + 0.69227005f, -0.57987869f, -0.22032876f, 0.38032767f, -0.05223263f, + 0.13137188f, 0.3667803f 
, -0.15853189f, 0.15085728f, 0.72258149f, + 0.03757231f, 0.17403452f, 0.69337627f, 0.16971045f, -0.21071186f, + 0.39185397f, -0.13084008f, 0.145886f , 0.47240727f, -0.1417591f , + -0.12659159f, 0.67937788f, -0.05867803f, -0.04813048f, 0.35710624f, + 0.47681283f, 0.24003804f, 0.1653288f , 0.00953913f, -0.05111816f, + 0.29417614f, -0.31640032f, 0.18433114f, 0.54718234f, -0.39812097f, + -0.24805083f, 0.61018603f, -0.40592682f, -0.22219216f, 0.39241133f, + -0.23560742f, 0.06353694f, 0.3067938f , -0.0304029f , 0.35893188f + }); + + auto actual = NDArrayFactory::create('c', { 5, 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgb); + ctx.setOutputArray(0, &actual); + + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_2) { + + auto rgb = NDArrayFactory::create('c', { 5, 3, 4 }, + { + 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, + 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, + 0.48942474f, 0.00158441f, 0.97605824f, 0.00112842f, 0.41196227f, + 0.30888894f, 0.02462568f, 0.99260217f, 0.3095014f , 0.3122602f , + 0.14837205f, 0.9585542f , 0.6620493f , 0.7993488f , 0.86656475f, + 0.72481847f, 0.3573504f , 0.25830218f, 0.5997049f , 0.7835693f , + 0.33301765f, 0.59289205f, 0.9776477f , 0.14649455f, 0.7853056f , + 0.41357264f, 0.5934154f , 0.96197623f, 0.1427159f , 0.10614152f, + 0.72647524f, 0.0720306f , 0.19581454f, 0.26093867f, 0.6623308f , + 0.23853847f, 0.06766324f, 0.9584985f , 0.01258832f, 0.08418505f, + 0.20662387f, 0.50057423f, 0.8160156f , 0.86440504f, 0.4153733f , + 0.08274968f, 0.56506383f, 0.6807802f , 0.76146203f, 0.9521758f + }); + + auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, + { + 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f, 0.32321111f, 0.69227005f, 
0.38032767f, + 0.3667803f , 0.52719408f, -0.57987869f, -0.05223263f, -0.15853189f, + 0.2397369f , -0.22032876f, 0.13137188f, 0.15085728f, 0.72258149f, + 0.69337627f, 0.39185397f, 0.47240727f, 0.03757231f, 0.16971045f, + -0.13084008f, -0.1417591f , 0.17403452f, -0.21071186f, 0.145886f , + -0.12659159f, 0.67937788f, 0.35710624f, 0.1653288f , 0.29417614f, + -0.05867803f, 0.47681283f, 0.00953913f, -0.31640032f, -0.04813048f, + 0.24003804f, -0.05111816f, 0.18433114f, 0.54718234f, 0.61018603f, + 0.39241133f, 0.3067938f , -0.39812097f, -0.40592682f, -0.23560742f, + -0.0304029f , -0.24805083f, -0.22219216f, 0.06353694f, 0.35893188f + }); + + auto actual = NDArrayFactory::create('c', { 5, 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgb); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_3) { + + auto rgb = NDArrayFactory::create('c', { 4, 3 }, + { + 0.48055f , 0.80757356f, 0.2564435f , 0.94277316f, 0.17006584f, + 0.33366168f, 0.41727918f, 0.54528666f, 0.48942474f, 0.3305715f , + 0.98633456f, 0.00158441f + }); + + auto expected = NDArrayFactory::create('c', { 4, 3 }, + { + 0.64696468f, -0.01777124f, -0.24070648f, 0.41975525f, 0.40788622f, + 0.21433232f, 0.50064416f, -0.05832884f, -0.04447775f, 0.67799989f, + -0.07432612f, -0.44518381f + }); + + auto actual = NDArrayFactory::create('c', { 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgb); + ctx.setOutputArray(0, &actual); + + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_4) { + + auto rgb = NDArrayFactory::create('c', { 3, 4 }, + { + 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, + 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, + 
0.48942474f, 0.00158441f + }); + + auto expected = NDArrayFactory::create('c', { 3, 4 }, + { + 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f + }); + + auto actual = NDArrayFactory::create('c', { 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &rgb); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); +} + + + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_5) { + + auto rgbs = NDArrayFactory::create('c', { 3 }, + { 0.48055f , 0.80757356f, 0.2564435f }); + auto expected = NDArrayFactory::create('c', { 3 }, + { 0.64696468f, -0.01777124f, -0.24070648f, }); + + + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &rgbs); + ctx.setOutputArray(0, &actual); + + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); +} + +TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) { + + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, + { + 0.48055f , 0.94277316f, 0.41727918f, 0.3305715f , 0.80757356f, + 0.17006584f, 0.54528666f, 0.98633456f, 0.2564435f , 0.33366168f, + 0.48942474f, 0.00158441f + }); + + auto yiqs = NDArrayFactory::create('c', { 3, 4 }, + { + 0.64696468f, 0.41975525f, 0.50064416f, 0.67799989f, -0.01777124f, + 0.40788622f, -0.05832884f, -0.07432612f, -0.24070648f, 0.21433232f, + -0.04447775f, -0.44518381f + }); + + //get subarray + NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); + NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) }); + subArrRgbs.reshapei({ 3 }); + expected.reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrRgbs.printShapeInfo("subArrRgbs"); +#endif + auto actual = 
NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &subArrRgbs); + ctx.setOutputArray(0, &actual); + sd::ops::rgb_to_yiq op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_1) { + + auto yiqs = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, + 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, + -0.471601307f, 0.263960421f, 0.700227439f, 0.32434237f, -0.278446227f, + 0.130805135f, -0.438441873f, 0.187127829f, 0.0276055578f, -0.179727226f, + 0.305075705f, 0.716282248f, 0.278215706f, -0.44586885f, 0.76971364f, + 0.131288841f, -0.141177326f, 0.900081575f, -0.0788725987f, 0.14756602f, + 0.387832165f, 0.229834676f, 0.47921446f, 0.632930398f, 0.0443540029f, + -0.268817365f, 0.0977194682f, -0.141669706f, -0.140715122f, 0.946808815f, + -0.52525419f, -0.106209636f, 0.659476519f, 0.391066104f, 0.426448852f, + 0.496989518f, -0.283434421f, -0.177366048f, 0.715208411f, -0.496444523f, + 0.189553142f, 0.616444945f, 0.345852494f, 0.447739422f, 0.224696323f, + 0.451372236f, 0.298027098f, 0.446561724f, -0.187599331f, -0.448159873f + }); + auto expected = NDArrayFactory::create('c', { 5, 4, 3 }, { + 0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, + 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, + 0.905021825f, 1.91936605f, 0.837427991f, 0.792213732f, -0.133271854f, + -0.17216571f, 0.128957025f, 0.934955336f, 0.0451873479f, -0.120952621f, + 0.746436225f, 0.705446224f, 0.929172217f, -0.351493549f, 0.807577594f, + 0.825371955f, 0.383812296f, 0.916293093f, 0.82603058f, 1.23885956f, + 0.905059196f, 0.015164554f, 0.950156781f, 0.508443732f, 0.794845279f, + 0.12571529f, -0.125074273f, 0.227326869f, 0.0147000261f, 0.378735409f, + 1.15842402f, 1.34712305f, 1.2980804f, 0.277102016f, 0.953435072f, + 0.115916842f, 
0.688879376f, 0.508405162f, 0.35829352f, 0.727568094f, + 1.58768577f, 1.22504294f, 0.232589777f, 0.996727258f, 0.841224629f, + -0.0909671176f, 0.233051388f, -0.0110094378f, 0.787642119f, -0.109582274f + }); + auto actual = NDArrayFactory::create('c', { 5, 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &yiqs); + ctx.setOutputArray(0, &actual); + + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_2) { + + auto yiqs = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, + -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, + 0.145902053f, 0.263960421f, 0.700227439f, 0.130805135f, 0.0276055578f, + 0.716282248f, 0.32434237f, -0.438441873f, -0.179727226f, 0.278215706f, + -0.278446227f, 0.187127829f, 0.305075705f, -0.44586885f, 0.76971364f, + 0.900081575f, 0.387832165f, 0.632930398f, 0.131288841f, -0.0788725987f, + 0.229834676f, 0.0443540029f, -0.141177326f, 0.14756602f, 0.47921446f, + -0.268817365f, 0.0977194682f, 0.946808815f, 0.659476519f, 0.496989518f, + -0.141669706f, -0.52525419f, 0.391066104f, -0.283434421f, -0.140715122f, + -0.106209636f, 0.426448852f, -0.177366048f, 0.715208411f, 0.616444945f, + 0.224696323f, 0.446561724f, -0.496444523f, 0.345852494f, 0.451372236f, + -0.187599331f, 0.189553142f, 0.447739422f, 0.298027098f, -0.448159873f + }); + auto expected = NDArrayFactory::create('c', { 5, 3, 4 }, { + 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, + -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, + 0.280231822f, 1.91936605f, 0.837427991f, -0.17216571f, 0.0451873479f, + 0.705446224f, 0.792213732f, 0.128957025f, -0.120952621f, 0.929172217f, + -0.133271854f, 0.934955336f, 0.746436225f, -0.351493549f, 0.807577594f, + 0.916293093f, 0.905059196f, 0.508443732f, 0.825371955f, 0.82603058f, + 
0.015164554f, 0.794845279f, 0.383812296f, 1.23885956f, 0.950156781f, + 0.12571529f, -0.125074273f, 0.378735409f, 1.2980804f, 0.115916842f, + 0.227326869f, 1.15842402f, 0.277102016f, 0.688879376f, 0.0147000261f, + 1.34712305f, 0.953435072f, 0.508405162f, 0.35829352f, 1.22504294f, + 0.841224629f, -0.0110094378f, 0.727568094f, 0.232589777f, -0.0909671176f, + 0.787642119f, 1.58768577f, 0.996727258f, 0.233051388f, -0.109582274f + }); + auto actual = NDArrayFactory::create('c', { 5, 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &yiqs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 1 }); + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_3) { + + auto yiqs = NDArrayFactory::create('c', { 4, 3 }, { + 0.775258899f, -0.288912386f, -0.132725924f, 0.0664454922f, -0.212469354f, + 0.455438733f, 0.418221354f, 0.349350512f, 0.145902053f, 0.947576523f, + -0.471601307f, 0.263960421f + }); + auto expected = NDArrayFactory::create('c', { 4, 3 }, { + 0.416663059f, 0.939747555f, 0.868814286f, 0.146075352f, -0.170521997f, + 1.07776645f, 0.842775284f, 0.228765106f, 0.280231822f, 0.660605291f, + 0.905021825f, 1.91936605f + }); + auto actual = NDArrayFactory::create('c', { 4, 3 }); + + Context ctx(1); + ctx.setInputArray(0, &yiqs); + ctx.setOutputArray(0, &actual); + + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_4) { + + auto yiqs = NDArrayFactory::create('c', { 3, 4 }, { + 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, + -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, + 0.145902053f, 0.263960421f + }); + auto expected = NDArrayFactory::create('c', { 3, 4 }, { + 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, + 
-0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, + 0.280231822f, 1.91936605f + }); + auto actual = NDArrayFactory::create('c', { 3, 4 }); + + Context ctx(1); + ctx.setInputArray(0, &yiqs); + ctx.setOutputArray(0, &actual); + ctx.setIArguments({ 0 }); + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); + +} + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_5) { + + auto yiqs = NDArrayFactory::create('c', { 3 }, { + 0.775258899f, -0.288912386f, -0.132725924f + }); + auto expected = NDArrayFactory::create('c', { 3 }, { + 0.416663059f, 0.939747555f, 0.868814286f + }); + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, &yiqs); + ctx.setOutputArray(0, &actual); + + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); +#if 0 + actual.printBuffer("actual"); + expected.printBuffer("expected"); +#endif + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); +} + +TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) { + + auto yiqs = NDArrayFactory::create('c', { 3, 4 }, { + 0.775258899f, 0.0664454922f, 0.418221354f, 0.947576523f, -0.288912386f, + -0.212469354f, 0.349350512f, -0.471601307f, -0.132725924f, 0.455438733f, + 0.145902053f, 0.263960421f + }); + auto rgbs = NDArrayFactory::create('c', { 3, 4 }, { + 0.416663059f, 0.146075352f, 0.842775284f, 0.660605291f, 0.939747555f, + -0.170521997f, 0.228765106f, 0.905021825f, 0.868814286f, 1.07776645f, + 0.280231822f, 1.91936605f + }); + + //get subarray + NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) }); + NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) }); + subArrYiqs.reshapei({ 3 }); + expected.reshapei({ 3 }); +#if 0 + //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER] + subArrYiqs.printShapeInfo("subArrYiqs"); +#endif + auto actual = NDArrayFactory::create('c', { 3 }); + + Context ctx(1); + ctx.setInputArray(0, 
&subArrYiqs); + ctx.setOutputArray(0, &actual); + sd::ops::yiq_to_rgb op; + auto status = op.execute(&ctx); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.equalsTo(actual)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_1) { + auto x= NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); + auto exp= NDArrayFactory::create('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0}); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {4.0}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(DeclarableOpsTests16, clipbynorm_2) { + auto x= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); + auto exp= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {6.0}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_3) { + + auto x = NDArrayFactory::create('c', {3, 5}); + auto unities = NDArrayFactory::create('c', {3, 1}, {1., 1., 1.}); + auto scale = NDArrayFactory::create('c', {3, 1}, {1.1, 1., 0.9}); + + x.linspace(100.); + + auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); + x /= xNorm1; + xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true); + + ASSERT_TRUE(unities.isSameShape(xNorm1)); + ASSERT_TRUE(unities.equalsTo(xNorm1)); + + x *= scale; + xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {1.0}, {1}); + auto z = result.at(0); + + auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true); + auto exp = NDArrayFactory::create('c', {3, 1}, {1., 1., xNorm1.e(2)}); + + ASSERT_TRUE(exp.isSameShape(&zNorm1)); + 
ASSERT_TRUE(exp.equalsTo(&zNorm1)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_4) { + + auto x = NDArrayFactory::create('c', {3, 5}, {0.7044955, 0.55606544, 0.15833677, 0.001874401, 0.61595726, 0.3924779, 0.7414847, 0.4127324, 0.24026828, 0.26093036, 0.46741188, 0.01863421, 0.08528871, 0.529365, 0.5510694}); + auto exp = NDArrayFactory::create('c', {3, 5}, {0.405392, 0.319980, 0.091113, 0.001079, 0.354444, 0.225846, 0.426676, 0.237501, 0.138259, 0.150149, 0.268965, 0.010723, 0.049078, 0.304615, 0.317105}); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_5) { + + // auto x = NDArrayFactory::create('c', {3, 5}, {1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5}); + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('c', {3, 5}, {1., 2., 2.89271, 3.50524, 4.00892, 6., 7., 7.71389, 7.88678, 8.01784, 11., 12., 12.53507, 12.26833, 12.02676}); + // auto exp = NDArrayFactory::create('c', {3, 5}, {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1}); + + x.linspace(1); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.f}, {0}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_6) { + + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 4.95434, 5.78006, 6.60578, 7.43151, 8.25723, 5.64288, 6.15587, 6.66886, 7.18185, 7.69484}); + + x.linspace(1); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.f}, {1}); + auto output = result.at(0); + + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_7) { + + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957}); + + x.linspace(1); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.f}, {0,1}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_8) { + + auto x = NDArrayFactory::create('c', {3, 5}); + auto exp = NDArrayFactory::create('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957}); + + x.linspace(1); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {15.}, {}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_9) { + + auto x = NDArrayFactory::create('c', {2}, {3., 4.}); + auto exp = NDArrayFactory::create('c', {2}, {2.4, 3.2}); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {4.}, {}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_10) { + + auto x = NDArrayFactory::create(6.); + auto exp = NDArrayFactory::create(5.); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {5.}, {}); + auto output = result.at(0); + + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_11) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1., 2., 3., 4., 4.44787, 5.33745, 6.22702, 7.1166 , 6.33046, 7.03384, 7.73723, 8.44061, + 13., 14., 15., 16., 15.12277, 16.01235, 16.90192, 17.7915 ,14.77107, 15.47446, 16.17784, 16.88123}); + + x.linspace(1); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {35.}, {0, 2}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_12) { + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5,6, 7, 8, 9}); + auto e = NDArrayFactory::create('c', {3, 3}, {0.03198684, 0.06397368, 0.09596053, 0.12794736, 0.15993419, 0.19192106, 0.22390789, 0.25589472, 0.28788155}); + + sd::ops::clipbynorm op; + auto result = op.evaluate({&x}, {0.54}, {}); + + ASSERT_EQ(e, *result.at(0)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_13) { + + const int bS = 5; + const int nOut = 4; + const int axis = 0; + const double clip = 2.; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1] + auto colVect = NDArrayFactory::create('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1}); + auto expect = NDArrayFactory::create('c', {bS, nOut}); + + auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut] + + auto y = ( (x / norm2) * clip) * colVect ; + auto temp = (x / norm2) * clip; + + for (int j = 0; j 
< nOut; ++j) { + auto yCol = y({0,0, j,j+1}); + const double norm2Col = yCol.reduceNumber(reduce::Norm2).e(0); + if (norm2Col <= clip) + expect({0,0, j,j+1}).assign(yCol); + else + expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) ); + } + + sd::ops::clipbynorm op; + auto result = op.evaluate({&y}, {clip}, {axis}); + auto outFF = result.at(0); + + ASSERT_TRUE(expect.isSameShape(outFF)); + ASSERT_TRUE(expect.equalsTo(outFF)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_bp_1) { + + const int bS = 2; + const int nOut = 3; + const double clip = 0.7; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); + + const OpArgsHolder argsHolderFF({&x}, {clip}, {}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {}); + + sd::ops::clipbynorm opFF; + sd::ops::clipbynorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_bp_2) { + + const int bS = 2; + const int nOut = 3; + const int axis = 0; + const double clip = 0.7; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); + + const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); + + sd::ops::clipbynorm opFF; + sd::ops::clipbynorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbynorm_bp_3) { + 
+ const int bS = 2; + const int nOut = 3; + const int axis = 1; + const double clip = 1.; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); + + const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); + + sd::ops::clipbynorm opFF; + sd::ops::clipbynorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbyavgnorm_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); + auto exp = NDArrayFactory::create('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0}); + + sd::ops::clipbyavgnorm op; + auto result = op.evaluate({&x}, {0.8}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbyavgnorm_2) { + auto x= NDArrayFactory::create('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f}); + auto exp= NDArrayFactory::create('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f}); + + sd::ops::clipbyavgnorm op; + auto result = op.evaluate({&x}, {0.9}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + + + + + + + + + + + + + + + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_1) { + + const int bS = 2; + const int nOut = 3; + const double clip = 0.7; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); + + const OpArgsHolder 
argsHolderFF({&x}, {clip}, {}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {}); + + sd::ops::clipbyavgnorm opFF; + sd::ops::clipbyavgnorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_2) { + + const int bS = 2; + const int nOut = 3; + const int axis = 1; + const double clip = 1.; + + auto x = NDArrayFactory::create('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1] + auto gradO = NDArrayFactory::create('c', {bS, nOut}); + + const OpArgsHolder argsHolderFF({&x}, {clip}, {axis}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis}); + + sd::ops::clipbyavgnorm opFF; + sd::ops::clipbyavgnorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_3) { + + NDArray x('c', {2, 3, 4}, {-0.14 ,0.96 ,0.47 ,-0.98 ,0.03 ,0.95 ,0.33 ,-0.97 ,0.59 ,-0.92 ,-0.12 ,-0.33 ,0.82 ,-0.76 ,-0.69 ,-0.95 ,-0.77 ,0.25 ,-0.35 ,0.94 ,0.50 ,0.04 ,0.61 ,0.99}, sd::DataType::DOUBLE); + NDArray gradO('c', {2, 3, 4}, sd::DataType::DOUBLE); + + const OpArgsHolder argsHolderFF({&x}, {0.7}, {0,2}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {0.7}, {0,2}); + + sd::ops::clipbyavgnorm opFF; + sd::ops::clipbyavgnorm_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests17.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests17.cpp new file mode 100644 index 000000000..be157ac40 --- /dev/null +++ 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests17.cpp @@ -0,0 +1,94 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace sd; + + +class DeclarableOpsTests17 : public testing::Test { +public: + + DeclarableOpsTests17() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests17, test_sparse_to_dense_1) { + auto values = NDArrayFactory::create({1.f, 2.f, 3.f}); + auto shape = NDArrayFactory::create({3, 3}); + auto ranges = NDArrayFactory::create({0,0, 1,1, 2,2}); + auto def = NDArrayFactory::create(0.f); + auto exp = NDArrayFactory::create('c', {3, 3}, {1.f,0.f,0.f, 0.f,2.f,0.f, 0.f,0.f,3.f}); + + + sd::ops::compat_sparse_to_dense op; + auto result = op.evaluate({&ranges, &shape, &values, &def}); + ASSERT_EQ(Status::OK(), result.status()); +} + +TEST_F(DeclarableOpsTests17, test_sparse_to_dense_2) { + auto values = NDArrayFactory::string({3}, {"alpha", "beta", "gamma"}); + auto shape = NDArrayFactory::create({3, 3}); + auto ranges = 
NDArrayFactory::create({0,0, 1,1, 2,2}); + auto def = NDArrayFactory::string("d"); + auto exp = NDArrayFactory::string( {3, 3}, {"alpha","d","d", "d","beta","d", "d","d","gamma"}); + + + sd::ops::compat_sparse_to_dense op; + auto result = op.evaluate({&ranges, &shape, &values, &def}); + ASSERT_EQ(Status::OK(), result.status()); + +} + +TEST_F(DeclarableOpsTests17, test_compat_string_split_1) { + auto x = NDArrayFactory::string( {2}, {"first string", "second"}); + auto delimiter = NDArrayFactory::string(" "); + + auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); + auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"}); + + sd::ops::compat_string_split op; + auto result = op.evaluate({&x, &delimiter}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(2, result.size()); + + auto z0 = result.at(0); + auto z1 = result.at(1); + + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp1.isSameShape(z1)); + + ASSERT_EQ(exp0, *z0); + ASSERT_EQ(exp1, *z1); + +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests18.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests18.cpp new file mode 100644 index 000000000..d6a3bb41d --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests18.cpp @@ -0,0 +1,1683 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. 
+ * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + + + // + // @author raver119@gmail.com + // + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace sd; + + +class DeclarableOpsTests18 : public testing::Test { +public: + + DeclarableOpsTests18() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(DeclarableOpsTests18, test_bitcast_1) { + auto x = NDArrayFactory::create(0.23028551377579154); + auto z = NDArrayFactory::create(0); + auto e = NDArrayFactory::create(4597464930322771456L); + + sd::ops::bitcast op; + auto status = op.execute({ &x }, { &z }, {}, { (Nd4jLong)sd::DataType::INT64 }, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests18, test_tanh_1) { + auto x = NDArrayFactory::create('c', { 8 }, { 0.23f, -0.23f, 0.35f, -0.35f, 0.64f, -0.64f, 100000.f, -100000.f }); + auto z = x.ulike(); + auto e = NDArrayFactory::create('c', { 8 }, { 0.226028f, -0.226028f, 0.336376f, -0.336376f, 0.564900f, -0.564900f, 1.f, -1.f }); + + sd::ops::tanh op; + op.execute({ &x }, { &z }); + + ASSERT_EQ(e, z); +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, test_tanh_2) { + + NDArray x('c', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); + NDArray z('c', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + + NDArray e('c', { 2, 2, 3, 3, 4, 4 }, { -0.761594, -0.760331, -0.759063, -0.757788, -0.756508, -0.755222, -0.753930, -0.752633, -0.751329, -0.750020, -0.748704, 
-0.747383, -0.746056, -0.744723, -0.743383, -0.742038, -0.740687, -0.739330, -0.737967, -0.736598, -0.735222, -0.733841, -0.732453, -0.731060, -0.729660, -0.728254, -0.726842, -0.725424, -0.724000, -0.722569, -0.721132, -0.719689, -0.718240, -0.716784, -0.715323, -0.713854, -0.712380, -0.710899, -0.709412, -0.707919, -0.706419, -0.704913, -0.703401, -0.701882, -0.700357, -0.698825, -0.697287, -0.695742, -0.694191, -0.692634, -0.691069, -0.689499, -0.687922, -0.686338, -0.684748, -0.683152, -0.681548, -0.679939, -0.678322, -0.676699, -0.675070, -0.673434, -0.671791, -0.670142, -0.668486, -0.666823, -0.665153, -0.663477, -0.661795, -0.660105, -0.658409, -0.656706, -0.654997, -0.653280, -0.651557, -0.649827, -0.648091, -0.646348, -0.644597, -0.642841, -0.641077, -0.639306, -0.637529, -0.635745, -0.633954, -0.632157, -0.630352, -0.628541, -0.626722, -0.624897, -0.623065, -0.621227, -0.619381, -0.617528, -0.615669, -0.613803, -0.611929, -0.610049, -0.608162, -0.606269, -0.604368, -0.602460, -0.600546, -0.598624, -0.596696, -0.594760, -0.592818, -0.590869, -0.588913, -0.586950, -0.584980, -0.583003, -0.581019, -0.579029, -0.577031, -0.575026, -0.573015, -0.570996, -0.568971, -0.566939, -0.564900, -0.562853, -0.560800, -0.558740, -0.556674, -0.554600, -0.552519, -0.550431, -0.548337, -0.546235, -0.544127, -0.542012, -0.539890, -0.537761, -0.535625, -0.533482, -0.531332, -0.529176, -0.527013, -0.524842, -0.522665, -0.520482, -0.518291, -0.516093, -0.513889, -0.511678, -0.509460, -0.507235, -0.505004, -0.502765, -0.500520, -0.498268, -0.496010, -0.493745, -0.491472, -0.489194, -0.486908, -0.484616, -0.482318, -0.480012, -0.477700, -0.475381, -0.473056, -0.470724, -0.468385, -0.466040, -0.463689, -0.461330, -0.458966, -0.456594, -0.454216, -0.451832, -0.449441, -0.447044, -0.444640, -0.442230, -0.439814, -0.437391, -0.434962, -0.432526, -0.430084, -0.427636, -0.425181, -0.422721, -0.420254, -0.417780, -0.415301, -0.412815, -0.410323, -0.407825, -0.405321, -0.402811, 
-0.400295, -0.397773, -0.395244, -0.392710, -0.390170, -0.387623, -0.385071, -0.382513, -0.379949, -0.377379, -0.374803, -0.372222, -0.369635, -0.367042, -0.364443, -0.361839, -0.359229, -0.356613, -0.353992, -0.351365, -0.348732, -0.346095, -0.343451, -0.340802, -0.338148, -0.335488, -0.332823, -0.330153, -0.327477, -0.324796, -0.322110, -0.319419, -0.316723, -0.314021, -0.311314, -0.308602, -0.305886, -0.303164, -0.300437, -0.297705, -0.294969, -0.292227, -0.289481, -0.286730, -0.283975, -0.281214, -0.278449, -0.275679, -0.272905, -0.270126, -0.267343, -0.264555, -0.261763, -0.258966, -0.256165, -0.253360, -0.250550, -0.247737, -0.244919, -0.242097, -0.239270, -0.236440, -0.233606, -0.230768, -0.227925, -0.225079, -0.222229, -0.219376, -0.216518, -0.213657, -0.210792, -0.207923, -0.205051, -0.202176, -0.199297, -0.196414, -0.193528, -0.190639, -0.187746, -0.184850, -0.181951, -0.179049, -0.176144, -0.173235, -0.170324, -0.167409, -0.164492, -0.161572, -0.158649, -0.155723, -0.152794, -0.149863, -0.146929, -0.143992, -0.141053, -0.138112, -0.135168, -0.132221, -0.129273, -0.126322, -0.123368, -0.120413, -0.117455, -0.114496, -0.111534, -0.108570, -0.105605, -0.102637, -0.099668, -0.096697, -0.093724, -0.090750, -0.087774, -0.084796, -0.081817, -0.078836, -0.075854, -0.072871, -0.069886, -0.066900, -0.063913, -0.060924, -0.057935, -0.054945, -0.051953, -0.048961, -0.045968, -0.042974, -0.039979, -0.036983, -0.033987, -0.030990, -0.027993, -0.024995, -0.021996, -0.018998, -0.015999, -0.012999, -0.010000, -0.007000, -0.004000, -0.001000, 0.002000, 0.005000, 0.008000, 0.011000, 0.013999, 0.016998, 0.019997, 0.022996, 0.025994, 0.028992, 0.031989, 0.034986, 0.037982, 0.040977, 0.043972, 0.046965, 0.049958, 0.052950, 0.055942, 0.058932, 0.061921, 0.064909, 0.067895, 0.070881, 0.073865, 0.076848, 0.079830, 0.082810, 0.085789, 0.088766, 0.091741, 0.094715, 0.097687, 0.100658, 0.103627, 0.106594, 0.109558, 0.112521, 0.115482, 0.118441, 0.121398, 0.124353, 0.127305, 
0.130256, 0.133204, 0.136149, 0.139092, 0.142033, 0.144971, 0.147907, 0.150840, 0.153771, 0.156698, 0.159623, 0.162545, 0.165465, 0.168381, 0.171294, 0.174205, 0.177112, 0.180017, 0.182918, 0.185816, 0.188711, 0.191602, 0.194490, 0.197375, 0.200257, 0.203135, 0.206009, 0.208880, 0.211747, 0.214611, 0.217471, 0.220327, 0.223180, 0.226028, 0.228873, 0.231714, 0.234551, 0.237384, 0.240213, 0.243038, 0.245858, 0.248675, 0.251487, 0.254296, 0.257099, 0.259899, 0.262694, 0.265485, 0.268271, 0.271053, 0.273830, 0.276603, 0.279371, 0.282135, 0.284894, 0.287648, 0.290397, 0.293142, 0.295882, 0.298617, 0.301347, 0.304072, 0.306792, 0.309507, 0.312217, 0.314922, 0.317622, 0.320317, 0.323006, 0.325691, 0.328370, 0.331044, 0.333712, 0.336376, 0.339033, 0.341686, 0.344333, 0.346974, 0.349611, 0.352241, 0.354866, 0.357485, 0.360099, 0.362707, 0.365310, 0.367907, 0.370498, 0.373083, 0.375663, 0.378236, 0.380804, 0.383366, 0.385922, 0.388473, 0.391017, 0.393555, 0.396088, 0.398614, 0.401134, 0.403649, 0.406157, 0.408659, 0.411155, 0.413644, 0.416128, 0.418605, 0.421077, 0.423542, 0.426000, 0.428453, 0.430899, 0.433339, 0.435772, 0.438199, 0.440620, 0.443034, 0.445442, 0.447844, 0.450239, 0.452628, 0.455010, 0.457385, 0.459755, 0.462117, 0.464473, 0.466823, 0.469166, 0.471502, 0.473832, 0.476155, 0.478471, 0.480781, 0.483085, 0.485381, 0.487671, 0.489954, 0.492231, 0.494500, 0.496763, 0.499020, 0.501269, 0.503512, 0.505748, 0.507977, 0.510200, 0.512416, 0.514624, 0.516827, 0.519022, 0.521210, 0.523392, 0.525567, 0.527735, 0.529896, 0.532050, 0.534197, 0.536338, 0.538471, 0.540598, 0.542718, 0.544831, 0.546937, 0.549036, 0.551128, 0.553213, 0.555292, 0.557363, 0.559428, 0.561486, 0.563536, 0.565580, 0.567617, 0.569647, 0.571670, 0.573686, 0.575695, 0.577697, 0.579693, 0.581681, 0.583663, 0.585637, 0.587605, 0.589566, 0.591519, 0.593466, 0.595406, 0.597339, 0.599265, 0.601184, 0.603097, 0.605002, 0.606901, 0.608792, 0.610677, 0.612555, 0.614425, 0.616289, 0.618147, 0.619997 }, 
sd::DataType::FLOAT32); + + sd::ops::tanh op; + op.execute({ &x }, { &z }); + ASSERT_EQ(e, z); +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, test_tanh_bp) { + + NDArray x('c', { 2, 3, 4 }, sd::DataType::FLOAT32); + NDArray dLdz('c', { 2, 3, 4 }, sd::DataType::FLOAT32); + NDArray dLdx('c', { 2, 3, 4 }, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + dLdz.linspace(0.01, 0.01); + + NDArray e('c', { 2, 3, 4 }, { 0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732 }, sd::DataType::FLOAT32); + + sd::ops::tanh_bp op; + op.execute({ &x, &dLdz }, { &dLdx }); + ASSERT_EQ(e, dLdx); +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, test_tanh_bp2) { + + NDArray x('f', { 2, 3, 4 }, sd::DataType::FLOAT32); + NDArray dLdz('f', { 2, 3, 4 }, sd::DataType::FLOAT32); + NDArray dLdx('f', { 2, 3, 4 }, sd::DataType::FLOAT32); + + x.linspace(-1., 0.003); + dLdz.linspace(0.01, 0.01); + + NDArray exp('c', { 2, 3, 4 }, { 0.004200, 0.008438, 0.012715, 0.017030, 0.021385, 0.025778, 0.030211, 0.034684, 0.039195, 0.043747, 0.048339, 0.052970, 0.057642, 0.062354, 0.067107, 0.071901, 0.076735, 0.081610, 0.086527, 0.091485, 0.096484, 0.101525, 0.106608, 0.111732 }, sd::DataType::FLOAT32); + NDArray e('f', { 2, 3, 4 }, sd::DataType::FLOAT32); + e.assign(exp); + + sd::ops::tanh_bp op; + op.execute({ &x, &dLdz }, { &dLdx }); + ASSERT_EQ(e, dLdx); +} +///////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, test_tanh_bp3) { + + NDArray x('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); + NDArray dLdz('f', { 2,2, 3,3, 4,4 }, sd::DataType::FLOAT32); + 
NDArray dLdx('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); + + x.linspace(-1.5, 0.005); + dLdz.linspace(-1., 0.01); + + NDArray exp('c', { 2, 2, 3, 3, 4, 4 }, { -0.180707, -0.180525, -0.180324, -0.180103, -0.179861, -0.179599, -0.179315, -0.179009, -0.178682, -0.178333, -0.177961, -0.177566, -0.177148, -0.176706, -0.176240, -0.175750, -0.175236, -0.174696, -0.174130, -0.173539, -0.172922, -0.172278, -0.171607, -0.170909, -0.170183, -0.169429, -0.168646, -0.167834, -0.166993, -0.166123, -0.165222, -0.164290, -0.163327, -0.162334, -0.161308, -0.160250, -0.159159, -0.158035, -0.156877, -0.155686, -0.154460, -0.153199, -0.151903, -0.150571, -0.149203, -0.147798, -0.146356, -0.144876, -0.143359, -0.141803, -0.140207, -0.138573, -0.136898, -0.135183, -0.133428, -0.131630, -0.129792, -0.127910, -0.125986, -0.124019, -0.122008, -0.119953, -0.117853, -0.115708, -0.113517, -0.111279, -0.108996, -0.106665, -0.104286, -0.101859, -0.099383, -0.096859, -0.094284, -0.091660, -0.088984, -0.086258, -0.083480, -0.080649, -0.077766, -0.074830, -0.071840, -0.068796, -0.065697, -0.062543, -0.059334, -0.056068, -0.052745, -0.049365, -0.045928, -0.042432, -0.038878, -0.035264, -0.031591, -0.027858, -0.024064, -0.020209, -0.016292, -0.012313, -0.008272, -0.004168, 0.000000, 0.004232, 0.008528, 0.012889, 0.017316, 0.021808, 0.026367, 0.030992, 0.035684, 0.040444, 0.045272, 0.050169, 0.055134, 0.060168, 0.065273, 0.070447, 0.075692, 0.081007, 0.086394, 0.091853, 0.097383, 0.102986, 0.108662, 0.114411, 0.120233, 0.126129, 0.132099, 0.138144, 0.144263, 0.150457, 0.156727, 0.163072, 0.169493, 0.175990, 0.182564, 0.189214, 0.195941, 0.202745, 0.209627, 0.216585, 0.223622, 0.230736, 0.237929, 0.245200, 0.252549, 0.259976, 0.267482, 0.275066, 0.282730, 0.290472, 0.298293, 0.306193, 0.314172, 0.322230, 0.330366, 0.338582, 0.346877, 0.355250, 0.363703, 0.372234, 0.380844, 0.389532, 0.398299, 0.407144, 0.416067, 0.425068, 0.434147, 0.443303, 0.452537, 0.461848, 0.471235, 0.480699, 0.490240, 
0.499856, 0.509548, 0.519314, 0.529156, 0.539072, 0.549062, 0.559126, 0.569262, 0.579471, 0.589753, 0.600106, 0.610530, 0.621024, 0.631588, 0.642222, 0.652924, 0.663694, 0.674532, 0.685436, 0.696406, 0.707441, 0.718541, 0.729704, 0.740931, 0.752219, 0.763568, 0.774978, 0.786448, 0.797976, 0.809561, 0.821203, 0.832901, 0.844654, 0.856460, 0.868319, 0.880230, 0.892191, 0.904201, 0.916260, 0.928366, 0.940518, 0.952715, 0.964955, 0.977238, 0.989561, 1.001925, 1.014327, 1.026767, 1.039242, 1.051752, 1.064295, 1.076870, 1.089475, 1.102109, 1.114771, 1.127459, 1.140171, 1.152907, 1.165664, 1.178441, 1.191237, 1.204050, 1.216878, 1.229720, 1.242573, 1.255438, 1.268311, 1.281192, 1.294078, 1.306968, 1.319860, 1.332753, 1.345644, 1.358533, 1.371417, 1.384294, 1.397163, 1.410022, 1.422870, 1.435704, 1.448522, 1.461323, 1.474105, 1.486867, 1.499606, 1.512321, 1.525009, 1.537669, 1.550299, 1.562897, 1.575462, 1.587991, 1.600483, 1.612935, 1.625347, 1.637715, 1.650040, 1.662317, 1.674545, 1.686724, 1.698850, 1.710922, 1.722939, 1.734897, 1.746797, 1.758635, 1.770409, 1.782119, 1.793762, 1.805337, 1.816842, 1.828274, 1.839633, 1.850916, 1.862121, 1.873248, 1.884294, 1.895258, 1.906137, 1.916931, 1.927637, 1.938255, 1.948782, 1.959216, 1.969557, 1.979802, 1.989950, 2.000000, 2.009950, 2.019798, 2.029543, 2.039184, 2.048719, 2.058147, 2.067466, 2.076675, 2.085773, 2.094759, 2.103630, 2.112386, 2.121026, 2.129548, 2.137952, 2.146235, 2.154397, 2.162437, 2.170354, 2.178146, 2.185813, 2.193353, 2.200766, 2.208051, 2.215207, 2.222232, 2.229127, 2.235889, 2.242520, 2.249017, 2.255379, 2.261607, 2.267699, 2.273656, 2.279475, 2.285158, 2.290702, 2.296108, 2.301376, 2.306503, 2.311491, 2.316339, 2.321046, 2.325613, 2.330038, 2.334321, 2.338464, 2.342464, 2.346322, 2.350037, 2.353610, 2.357041, 2.360329, 2.363475, 2.366478, 2.369338, 2.372056, 2.374632, 2.377065, 2.379356, 2.381505, 2.383512, 2.385378, 2.387103, 2.388686, 2.390128, 2.391431, 2.392593, 2.393615, 2.394499, 2.395244, 2.395850, 
2.396319, 2.396650, 2.396845, 2.396904, 2.396826, 2.396615, 2.396268, 2.395789, 2.395176, 2.394431, 2.393554, 2.392547, 2.391410, 2.390144, 2.388749, 2.387227, 2.385578, 2.383804, 2.381904, 2.379880, 2.377734, 2.375465, 2.373075, 2.370565, 2.367936, 2.365188, 2.362324, 2.359343, 2.356247, 2.353038, 2.349715, 2.346280, 2.342735, 2.339080, 2.335316, 2.331445, 2.327468, 2.323386, 2.319200, 2.314912, 2.310522, 2.306031, 2.301442, 2.296754, 2.291970, 2.287090, 2.282116, 2.277049, 2.271890, 2.266641, 2.261302, 2.255876, 2.250362, 2.244763, 2.239080, 2.233314, 2.227467, 2.221538, 2.215531, 2.209445, 2.203284, 2.197047, 2.190736, 2.184352, 2.177897, 2.171371, 2.164777, 2.158115, 2.151386, 2.144592, 2.137735, 2.130815, 2.123833, 2.116792, 2.109692, 2.102533, 2.095320, 2.088051, 2.080727, 2.073352, 2.065925, 2.058447, 2.050921, 2.043347, 2.035727, 2.028061, 2.020351, 2.012599, 2.004804, 1.996969, 1.989094, 1.981181, 1.973232, 1.965246, 1.957225, 1.949171, 1.941084, 1.932965, 1.924816, 1.916638, 1.908432, 1.900198, 1.891938, 1.883654, 1.875345, 1.867014, 1.858661, 1.850286, 1.841892, 1.833479, 1.825048, 1.816600, 1.808136, 1.799657, 1.791165, 1.782659, 1.774141, 1.765612, 1.757073, 1.748523, 1.739967, 1.731401, 1.722829, 1.714251, 1.705668, 1.697082, 1.688491, 1.679897, 1.671302, 1.662707, 1.654110, 1.645514, 1.636920, 1.628328, 1.619738, 1.611152, 1.602570, 1.593993, 1.585422, 1.576857, 1.568299, 1.559749, 1.551207, 1.542674, 1.534151, 1.525638, 1.517136, 1.508645, 1.500167, 1.491701, 1.483248, 1.474810, 1.466385, 1.457976, 1.449581, 1.441203, 1.432841, 1.424496, 1.416169, 1.407860, 1.399569, 1.391297, 1.383045, 1.374812, 1.366600, 1.358408, 1.350237, 1.342088, 1.333961, 1.325856, 1.317774, 1.309715, 1.301679, 1.293668, 1.285680, 1.277718, 1.269780, 1.261867, 1.253980, 1.246119, 1.238283, 1.230474, 1.222692, 1.214937, 1.207210, 1.199510, 1.191837, 1.184193, 1.176577, 1.168990, 1.161430, 1.153901, 1.146401, 1.138930, 1.131489, 1.124077, 1.116696, 1.109345, 1.102024, 1.094734, 
1.087475, 1.080246, 1.073049 }, sd::DataType::FLOAT32); + + NDArray e('f', { 2, 2, 3, 3, 4, 4 }, sd::DataType::FLOAT32); + e.assign(exp); + + sd::ops::tanh_bp op; + op.execute({ &x, &dLdz }, { &dLdx }); + ASSERT_EQ(e, dLdx); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST) { + + NDArray input('c', { 2, 2 }, { 1,2,3,4 }, DataType::FLOAT32); + NDArray epsilon('c', { 2, 2 }, { .1, .2, .3, .4 }, DataType::FLOAT32); + + int axis = 1; + + NDArray output('c', { 2, 2 }, DataType::FLOAT32); + + NDArray exp('c', { 2, 2 }, { -0.019661, 0.019661, -0.019661, 0.019661 }, DataType::FLOAT32); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST2) { + + NDArray input('c', { 4, 5, 2, 3 }, DataType::FLOAT32); + NDArray epsilon('c', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 
0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32); + input.linspace(0.1, 0.2); + + int axis = -1; + + NDArray output('c', { 4, 5, 2, 3 }, DataType::FLOAT32); + NDArray exp('c', { 4, 5, 2, 3 }, { -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253, -0.009387, -0.002866, 0.012253 }, DataType::FLOAT32); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); + ASSERT_EQ(ND4J_STATUS_OK, status); + 
ASSERT_TRUE(output.equalsTo(exp)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestSoftMax_bp_TEST3) { + + NDArray input('f', { 4, 5, 2, 3 }, DataType::FLOAT32); + NDArray epsilon('f', { 4, 5, 2, 3 }, { -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855, -0.030498, -0.004357, 0.034855 }, DataType::FLOAT32); + input.linspace(-5., 0.5); + + int axis = 1; + + NDArray output('f', { 4, 5, 2, 3 }, DataType::FLOAT32); + NDArray expC('c', { 4, 5, 2, 3 }, { -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909, -0.000000, 0.000000, -0.000000, 
-0.000000, -0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, -0.000149, 0.000054, 0.000095, 0.000095, -0.000149, 0.000054, -0.001760, 0.002943, -0.001183, -0.001183, -0.001760, 0.002943, 0.001909, -0.002997, 0.001088, 0.001088, 0.001909, -0.002997, 0.000000, -0.000000, -0.000000, -0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000000, -0.000000, 0.000000, 0.000054, 0.000095, -0.000149, -0.000149, 0.000054, 0.000095, 0.002943, -0.001183, -0.001760, -0.001760, 0.002943, -0.001183, -0.002997, 0.001088, 0.001909, 0.001909, -0.002997, 0.001088, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, -0.000000, -0.000000, 0.000000, 0.000000, 0.000095, -0.000149, 0.000054, 0.000054, 0.000095, -0.000149, -0.001183, -0.001760, 0.002943, 0.002943, -0.001183, -0.001760, 0.001088, 0.001909, -0.002997, -0.002997, 0.001088, 0.001909 }, DataType::FLOAT32); + + NDArray exp('f', { 4, 5, 2, 3 }, DataType::FLOAT32); + exp.assign(expC); + + sd::ops::softmax_bp op; + + Nd4jStatus status = op.execute({ &input, &epsilon }, { &output }, {}, { axis }); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(output.equalsTo(exp)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_1) { + + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + NDArray dLdz('c', { 2, 2 }, DataType::FLOAT32); + dLdz.linspace(1); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 2,3 }, { 17.f, 14.f, 10.f, 45.f, 32.f, 26.f }); + auto edLdw = NDArrayFactory::create('c', { 3,2 }, { 
43.f, 58.f, 26.f, 42.f, 21.f, 30.f }); + auto edLdb = NDArrayFactory::create('c', { 2 }, { 4.f, 6.f }); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_2) { + + auto x = NDArrayFactory::create('c', { 6,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('c', { 3,4 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f, 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create('c', { 4 }, { 100.f, 200.f, 100.f, 200.f }); + + NDArray dLdz('c', { 6, 4 }, DataType::FLOAT32); + dLdz.linspace(.1, .5); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 6,3 }, { 15.3f, 18.700001f, 13.2f, 61.299995f, 62.699997f, 47.200001f, 107.299995f, 106.699997f, 81.199997f, 153.299988f, 150.699997f, 115.199997f, 199.300018f, 194.700012f, 149.199997f, 245.300018f, 238.700012f, 183.199997f }); + auto edLdw = NDArrayFactory::create('c', { 3,4 }, { 268.5f, 291.f, 313.5f, 336.f, 226.800003f, 250.800003f, 274.799988f, 298.799988f, 146.699997f, 160.199997f, 173.700012f, 187.200012f }); + auto edLdb = NDArrayFactory::create('c', { 4 }, { 30.6f, 33.599998f, 36.599998f, 39.599998f }); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); +} + +//////////////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests18, XWPlusB_Bp_3) { + + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto w = NDArrayFactory::create('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f, 300.f }); + + auto dLdz = NDArrayFactory::create('c', { 1, 3 }, { 166.f, 269.f, 326.f }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 1,2 }, { 3937.f, 3096.f }); + auto edLdw = NDArrayFactory::create('c', { 2,3 }, { 166.f, 269.f, 326.f, 1826.f, 2959.f, 3586.f }); + auto edLdb = NDArrayFactory::create('c', { 3 }, { 166.f, 269.f, 326.f }); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_4) { + + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto w = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); + auto b = NDArrayFactory::create('c', { 1 }, { 200.f }); + + auto dLdz = NDArrayFactory::create('c', { 1,1 }, { 244.f }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 1,2 }, { 2684.f, 732.f }); + auto edLdw = NDArrayFactory::create('c', { 2,1 }, { 244.f, 2684.f }); + auto edLdb = NDArrayFactory::create('c', { 1 }, { 244.f }); + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + 
ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_5) { + + auto x = NDArrayFactory::create('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto dLdz = NDArrayFactory::create('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdxC = NDArrayFactory::create('c', { 2,3 }, { 2705.f, 1818.f, 1026.f, 4912.f, 2967.f, 1850.f }); + auto edLdwC = NDArrayFactory::create('c', { 3,2 }, { 3297.f, 4094.f, 4438.f, 5613.f, 2422.f, 3271.f }); + auto edLdbC = NDArrayFactory::create('c', { 2 }, { 427.f, 584.f }); + + auto edLdx = NDArrayFactory::create('f', { 2,3 }); + auto edLdw = NDArrayFactory::create('f', { 3,2 }); + auto edLdb = NDArrayFactory::create('f', { 2 }); + + edLdx.assign(edLdxC); + edLdw.assign(edLdwC); + edLdb.assign(edLdbC); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, XWPlusB_Bp_6) { + + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto w = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto dLdz = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + + // mkl-format + 
w.permutei({ 1,0 }); + + sd::ops::xw_plus_b_bp op; + auto result = op.evaluate({ &x, &w, &b, &dLdz }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto dLdx = result.at(0); + auto dLdw = result.at(1); + auto dLdb = result.at(2); + + auto edLdx = NDArrayFactory::create('c', { 2,3 }, { 2695.f, 2012.f, 1566.f, 4247.f, 2635.f, 2418.f }); + auto edLdwC = NDArrayFactory::create('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f }); + auto edLdb = NDArrayFactory::create('c', { 2 }, { 483.f, 543.f }); + auto edLdw = NDArrayFactory::create('c', { 3,2 }, { 4513.f, 3453.f, 2379.f, 4170.f, 4299.f, 2466.f }); + edLdw.permutei({ 1,0 }); + edLdw.assign(edLdwC); + + ASSERT_TRUE(edLdx.isSameShape(dLdx)); + ASSERT_TRUE(edLdw.isSameShape(dLdw)); + ASSERT_TRUE(edLdb.isSameShape(dLdb)); + ASSERT_TRUE(edLdx.equalsTo(dLdx)); + ASSERT_TRUE(edLdw.equalsTo(dLdw)); + ASSERT_TRUE(edLdb.equalsTo(dLdb)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterSgd1) { + + NDArray gradient('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + auto lr = NDArrayFactory::create(0.001f); + + NDArray update('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); + + sd::ops::sgd_updater op; + + Nd4jStatus status = op.execute({ &gradient, &lr }, { &gradient }, {}, { }); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(update.equalsTo(gradient)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterSgd2) { + + NDArray gradient('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 
0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); + + sd::ops::sgd_updater op; + + Nd4jStatus status = op.execute({ &gradient }, { &gradient }, { 0.001f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(update.equalsTo(gradient)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterSgd3) { + + NDArray gradientC('c', { 1, 5 }, { 0.21138794720172882, 0.38947954773902893, 0.2822134494781494, 0.4342866837978363, 0.7928546667098999 }, DataType::FLOAT32); + + NDArray updateC('c', { 1, 5 }, { 0.00021138794720173, 0.00038947954773903, 0.00028221344947815, 0.00043428668379784, 0.0007928546667099 }, DataType::FLOAT32); + + NDArray gradient('f', { 1, 5 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + gradient.assign(gradientC); + update.assign(updateC); + + sd::ops::sgd_updater op; + + auto results = op.evaluate({ &gradient }, { 0.001f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm1) { + + NDArray grad0('c', { 1, 5 }, { 0.1811431348323822, 0.10499879717826843, 0.8736756443977356, 0.9707390666007996, 0.7415646314620972 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto decay = NDArrayFactory::create(0.95f); + auto epsilon = NDArrayFactory::create(1.e-8f); + + sd::ops::rms_prop_updater op; + + Nd4jStatus status = op.execute({ &grad0, &init, &lr, &decay, &epsilon }, { &grad0, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp0('c', { 1, 5 }, { 0.4472121903197142, 0.4472095514452829, 0.4472135169488324, 0.44721352981195367, 0.44721349127249754 }, DataType::FLOAT32); + NDArray 
stateG0('c', { 1, 5 }, { 0.00164065126484513, 0.00055124687044416, 0.03816546608068996, 0.04711672627124962, 0.02749591463177582 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateG0)); + + + NDArray grad1('c', { 1, 5 }, { 0.0139725673943758, 0.19333727657794952, 0.9288347363471985, 0.9253600239753723, 0.3578299283981323 }, DataType::FLOAT32); + status = op.execute({ &grad1, &init, &lr, &decay, &epsilon }, { &grad1, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.03528177364993147, 0.3952537075263024, 0.32964378302079766, 0.31269398966616074, 0.1984174163852542 }, DataType::FLOAT32); + NDArray stateG1('c', { 1, 5 }, { 0.00156838033358239, 0.00239264965265088, 0.07939389114891399, 0.08757544865627226, 0.03252323178305766 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateG1)); + + NDArray grad2('c', { 1, 5 }, { 0.5442887544631958, 0.5386605262756348, 0.884294331073761, 0.15599730610847473, 0.7259345054626465 }, DataType::FLOAT32); + status = op.execute({ &grad2, &init, &lr, &decay, &epsilon }, { &grad2, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 0.4262874753567082, 0.41582357367557454, 0.2613066321005825, 0.05369221235564697, 0.3034061716240995 }, DataType::FLOAT32); + NDArray stateG2('c', { 1, 5 }, { 0.01630247372865814, 0.01678077529839554, 0.11452301978992785, 0.0844134341991137, 0.05724611550496966 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateG2)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm2) { + + NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); 
+ + sd::ops::rms_prop_updater op; + + Nd4jStatus status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp0('c', { 1, 5 }, { 0.4472135330146769, 0.44721357487863594, 0.44721358411270346, 0.4472135878446271, 0.447213589800546 }, DataType::FLOAT32); + NDArray stateG0('c', { 1, 5 }, { 0.05000000950000005, 0.2000000095000002, 0.4500000095000004, 0.8000000095000007, 1.250000009500001 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateG0)); + + status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.32025628253164734, 0.3202562987764395, 0.32025630254446874, 0.3202563041196892, 0.3202563049660074 }, DataType::FLOAT32); + NDArray stateG1('c', { 1, 5 }, { 0.09750000902500008, 0.3900000090250003, 0.8775000090250007, 1.5600000090250012, 2.437500009025002 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateG1)); + + status = op.execute({ &grad, &init }, { &update, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 0.2647903457769699, 0.2647903552517623, 0.26479035752571606, 0.2647903584968847, 0.2647903590265272 }, DataType::FLOAT32); + NDArray stateG2('c', { 1, 5 }, { 0.1426250085737501, 0.5705000085737504, 1.283625008573751, 2.2820000085737515, 3.565625008573753 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateG2)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterRmsProm3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initC('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, 
DataType::FLOAT32); + NDArray init('f', { 1, 5 }, DataType::FLOAT32); + grad.assign(gradC); + init.assign(initC); + + sd::ops::rms_prop_updater op; + auto results = op.evaluate({ &grad, &init }, { 0.1f, 0.95f, 1.e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.4472135330146769, 0.44721357487863594, 0.44721358411270346, 0.4472135878446271, 0.447213589800546 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateG0C('c', { 1, 5 }, { 0.05000000950000005, 0.2000000095000002, 0.4500000095000004, 0.8000000095000007, 1.250000009500001 }, DataType::FLOAT32); + NDArray stateG('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateG.assign(stateG0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.32025628253164734, 0.3202562987764395, 0.32025630254446874, 0.3202563041196892, 0.3202563049660074 }, DataType::FLOAT32); + NDArray stateG1C('c', { 1, 5 }, { 0.09750000902500008, 0.3900000090250003, 0.8775000090250007, 1.5600000090250012, 2.437500009025002 }, DataType::FLOAT32); + + update.assign(update1C); + stateG.assign(stateG1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.95f, 1.e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', { 1, 5 }, { 0.2647903457769699, 0.2647903552517623, 0.26479035752571606, 0.2647903584968847, 0.2647903590265272 }, DataType::FLOAT32); + NDArray stateG2C('c', 
{ 1, 5 }, { 0.1426250085737501, 0.5705000085737504, 1.283625008573751, 2.2820000085737515, 3.565625008573753 }, DataType::FLOAT32); + + update.assign(update2C); + stateG.assign(stateG2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaGrad1) { + + // need Java test + + NDArray grad0('c', { 1, 5 }, { 0.1811431348323822, 0.10499879717826843, 0.8736756443977356, 0.9707390666007996, 0.7415646314620972 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto epsilon = NDArrayFactory::create(1.e-8f); + + sd::ops::ada_grad_updater op; + + Nd4jStatus status = op.execute({ &grad0, &init, &lr, &epsilon }, { &grad0, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs1) { + + NDArray grad0('c', { 1, 5 }, { 0.6877592206001282, 0.7830561399459839, 0.7647699117660522, 0.6183066964149475, 0.3303879499435425 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + sd::ops::nesterovs_updater op; + + Nd4jStatus status = op.execute({ &grad0, &init }, { &grad0, &init }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.13067425191402435, 0.14878066658973696, 0.14530628323554992, 0.11747827231884002, 0.06277371048927306 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { -0.06877592206001282, -0.0783056139945984, -0.07647699117660522, -0.06183066964149475, -0.03303879499435425 }, DataType::FLOAT32); + + 
ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateV0)); + + NDArray grad1('c', { 1, 5 }, { 0.3676236569881439, 0.07645636051893234, 0.45949840545654297, 0.6335387825965881, 0.2953402101993561 }, DataType::FLOAT32); + status = op.execute({ &grad1, &init }, { &grad1, &init }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.12555699169635773, 0.07795425583422186, 0.14925105988979342, 0.17045521110296247, 0.08287606388330458 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { -0.09866069555282593, -0.0781206886470318, -0.11477913260459902, -0.11900148093700408, -0.05926893651485443 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateV1)); + + NDArray grad2('c', { 1, 5 }, { 0.9874004125595093, 0.41817641258239746, 0.16838215291500092, 0.00803728867322206, 0.37015461921691895 }, DataType::FLOAT32); + status = op.execute({ &grad2, &init }, { &grad2, &init }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 0.26752124178409575, 0.1427312761947513, 0.12496370646357537, 0.09791828440688549, 0.11833721622824667 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { -0.18753466725349427, -0.11212626104056837, -0.12013943463563921, -0.10790506171062587, -0.09035750478506088 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateV2)); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs2) { + + NDArray grad('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray init('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.1f); + auto momentum = NDArrayFactory::create(0.9f); + + sd::ops::nesterovs_updater op; + + Nd4jStatus status = 
op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.19, 0.38, 0.5700000000000001, 0.76, 0.95 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { -0.1, -0.2, -0.30000000000000004, -0.4, -0.5 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(init.equalsTo(stateV0)); + + status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp1('c', { 1, 5 }, { 0.27099999999999996, 0.5419999999999999, 0.813, 1.0839999999999999, 1.355 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { -0.19, -0.38, -0.5700000000000001, -0.76, -0.95 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(init.equalsTo(stateV1)); + + status = op.execute({ &grad, &init, &lr, &momentum }, { &update, &init }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + NDArray updateExp2('c', { 1, 5 }, { 0.3439, 0.6878, 1.0317, 1.3756, 1.7195 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { -0.271, -0.542, -0.8130000000000002, -1.084, -1.355 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(init.equalsTo(stateV2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNesterovs3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initC('c', { 1, 5 }, { 0.00000001, 0.00000001, 0.00000001, 0.00000001, 0.00000001 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray init('f', { 1, 5 }, DataType::FLOAT32); + grad.assign(gradC); + init.assign(initC); + + sd::ops::nesterovs_updater op; + auto results = op.evaluate({ &grad, &init }, { 0.1f, 0.9f }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.19, 0.38, 0.5700000000000001, 0.76, 0.95 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, 
DataType::FLOAT32); + + NDArray stateG0C('c', { 1, 5 }, { -0.1, -0.2, -0.30000000000000004, -0.4, -0.5 }, DataType::FLOAT32); + NDArray stateG('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateG.assign(stateG0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.27099999999999996, 0.5419999999999999, 0.813, 1.0839999999999999, 1.355 }, DataType::FLOAT32); + NDArray stateG1C('c', { 1, 5 }, { -0.19, -0.38, -0.5700000000000001, -0.76, -0.95 }, DataType::FLOAT32); + + update.assign(update1C); + stateG.assign(stateG1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); + + + results = op.evaluate({ &grad, &stateG }, { 0.1f, 0.9f }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', { 1, 5 }, { 0.3439, 0.6878, 1.0317, 1.3756, 1.7195 }, DataType::FLOAT32); + NDArray stateG2C('c', { 1, 5 }, { -0.271, -0.542, -0.8130000000000002, -1.084, -1.355 }, DataType::FLOAT32); + + update.assign(update2C); + stateG.assign(stateG2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateG.isSameShape(results.at(1))); + ASSERT_TRUE(stateG.equalsTo(results.at(1))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 
0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::ada_max_updater op; + + Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); + NDArray stateU('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.0019, 0.0019, 0.0019, 0.0019, 0.0019 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00271, 0.00271, 0.00271, 0.00271, 0.00271 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax2) { + + NDArray grad0('c', { 1, 5 }, { 0.05387359112501144, 0.9700437784194946, 0.8912011384963989, 0.8891847729682922, 0.18823780119419098 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::ada_max_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); + NDArray stateU0('c', { 1, 5 }, { 0.05387359112501144, 0.9700437784194946, 0.8912011384963989, 0.8891847729682922, 0.18823780119419098 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.00538735911250114, 0.09700437784194944, 0.08912011384963987, 0.08891847729682921, 0.01882378011941909 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.6400517821311951, 0.3779360353946686, 0.35128724575042725, 0.6554615497589111, 0.8420050740242004 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00107575360832691, 0.00129089809294599, 0.00129546826560191, 0.00163878765669416, 0.00120120308808246 }, DataType::FLOAT32); + NDArray stateU1('c', { 1, 5 }, { 0.6400517821311951, 0.9690737346410752, 0.8903099373579025, 0.888295588195324, 0.8420050740242004 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.06885380141437052, 
0.12509754359722136, 0.11533682703971859, 0.1455727845430374, 0.10114190950989721 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.5984494686126709, 0.05978915095329285, 0.5749519467353821, 0.2804091274738312, 0.0192152876406908 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00190508497658779, 0.00122473022928962, 0.00181352349370876, 0.00179237223044249, 0.00110500865710834 }, DataType::FLOAT32); + NDArray stateU2('c', { 1, 5 }, { 0.6394117303490638, 0.9681046609064341, 0.8894196274205446, 0.8874072926071286, 0.8411630689501762 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.12181336813420054, 0.11856670433282851, 0.16129833900928492, 0.15905641883611676, 0.09294924732297657 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaMax3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::ada_max_updater op; + auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.001, 0.001, 0.001, 0.001, 0.001 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, 
DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.0019, 0.0019, 0.0019, 0.0019, 0.0019 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + update.assign(update1C); + stateM.assign(stateM1C); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update2C('c', { 1, 5 }, { 0.00271, 0.00271, 0.00271, 0.00271, 0.00271 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + update.assign(update2C); + stateM.assign(stateM2C); + + 
ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdam1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::adam_updater op; + + Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray stateV('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 
0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdam2) { + + NDArray grad0('c', { 1, 5 }, { 0.7124611735343933, 0.7283763289451599, 0.8196553587913513, 0.9501070976257324, 0.2654055953025818 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::adam_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999955614757, 0.00099999956584582, 0.00099999961419438, 0.0009999996671663, 0.00099999880851273 }, 
DataType::FLOAT32); + NDArray stateU0('c', { 1, 5 }, { 0.00050760092379401, 0.00053053207656763, 0.00067183490719538, 0.00090270349695879, 0.00007044013001792 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.07124611735343932, 0.07283763289451597, 0.08196553587913512, 0.09501070976257323, 0.02654055953025817 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateU0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.4374369978904724, 0.11488933861255646, 0.6765823364257812, 0.7659900188446045, 0.04410457238554955 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00129067017716555, 0.00104532555849556, 0.00133106720937621, 0.00132869584719374, 0.00105226561254395 }, DataType::FLOAT32); + NDArray stateU1('c', { 1, 5 }, { 0.00069844444999364, 0.00054320110461789, 0.00112892673025155, 0.00148854150243139, 0.00007231490319321 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.10786520540714262, 0.07704280346632002, 0.14142721593379973, 0.16210864067077635, 0.02829696081578731 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateU1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.496029257774353, 0.11621368676424026, 0.9112075567245483, 0.5717480182647705, 0.5975669026374817 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initU, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initU, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00150986322036664, 0.00108559662275258, 0.00156079502787382, 0.00150778241516558, 0.00130066803775601 }, DataType::FLOAT32); + NDArray stateU2('c', { 1, 5 }, { 0.00094379103011182, 0.00055616352450461, 0.00195809701495322, 
0.00181394875731865, 0.00042932879141777 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.14668161064386365, 0.08095989179611204, 0.21840525001287456, 0.20307257843017573, 0.08522395499795674 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateU2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdam3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + sd::ops::adam_updater op; + auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + 
ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray update2C('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + 
ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); +} +// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaBelief1) { + //here is the python code used for generating test numbers + //import numpy as np + //alpha=0.001 + //beta1=0.9 + //beta2=0.999 + //epsilon=1.e-8 + //#https://arxiv.org/pdf/2010.07468.pdf + //def update( t, w, gradW, mt, st): + // mt = beta1* mt + (1- beta1)*gradW + // st = beta2* st + (1- beta2)*((gradW-mt)**2) + epsilon + // mt_corr = mt/(1- beta1**t) + // st_corr = st/(1- beta2**t) + // upW= alpha*(mt_corr/(np.sqrt(st_corr)+epsilon)) + // w = w - upW + // return ( w, upW, mt, st ) + //#if you want to test with more precision np.set_printoptions(precision=9) + //grad = np.array([1,2,3,4,5], dtype = np.float32) + //w=np.zeros(5, dtype = np.float32) + //mt=np.zeros(5, dtype = np.float32) + //st = np.zeros(5, dtype = np.float32) + //for t in range(1,4): + // w, upW, mt, st = update(t,w,grad, mt,st ) + // print(f"---{t}----") + // print(f"update {upW}") + // print(f" s state {st} ") + // print(f" m state {mt} ") + + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initU('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::adabelief_updater op; + auto t=0; + Nd4jStatus status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.0011111f, 0.00111111f, 0.00111111f, 0.00111111f, 0.00111111f }, DataType::FLOAT32); + NDArray stateV('c', { 1, 5 }, { 0.00081001f, 0.00324001f, 0.00729001f, 0.01296001f, 0.02025001f }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}, 
DataType::FLOAT32); + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initU.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + t=1; + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { t}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.001168f, 0.001168f, 0.001168f, 0.001168f, 0.001168f}, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.00146531f, 0.00586118f, 0.01318763f, 0.02344466f, 0.03663227f }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.19f, 0.38f, 0.57000005f, 0.76f, 0.95f }, DataType::FLOAT32); + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initU.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + t=2; + status = op.execute({ &grad, &initU, &initM }, { &update, &initU, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, {t}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00122557f, 0.00122558f, 0.00122558f, 0.00122558f, 0.00122558f }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.0019953f, 0.00798109f, 0.01795742f, 0.03192428f, 0.04988168f }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.271f, 0.542f, 0.813f, 1.084f, 1.355f }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initU.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initMsg('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMsdx('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::ada_delta_updater op; + + Nd4jStatus status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + 
ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00447209123431084, 0.00447212477470162, 0.00447213098596791, 0.00447213315991723, 0.00447213416614627 }, DataType::FLOAT32); + NDArray stateMsg0('c', { 1, 5 }, { 0.05000000000000004, 0.20000000000000018, 0.4500000000000004, 0.8000000000000007, 1.250000000000001 }, DataType::FLOAT32); + NDArray stateMsdx0('c', { 1, 5 }, { 0.0000009999800004, 0.00000099999500002, 0.00000099999777778, 0.00000099999875, 0.0000009999992 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); + + status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); + NDArray stateMsg1('c', { 1, 5 }, { 0.09750000000000009, 0.39000000000000035, 0.8775000000000008, 1.5600000000000014, 2.4375000000000018 }, DataType::FLOAT32); + NDArray stateMsdx1('c', { 1, 5 }, { 0.00000197560125063, 0.00000197563108174, 0.00000197563660612, 0.00000197563853966, 0.00000197563943461 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); + + status = op.execute({ &grad, &initMsg, &initMsdx }, { &update, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00456759948242601, 0.00456763438748812, 0.00456764085147516, 0.00456764311387702, 0.004567644161047 }, DataType::FLOAT32); + NDArray stateMsg2('c', { 1, 5 }, { 0.1426250000000001, 0.5705000000000005, 1.2836250000000011, 2.282000000000002, 3.5656250000000025 }, DataType::FLOAT32); + NDArray stateMsdx2('c', { 1, 5 }, { 0.0000029199694397, 0.00000292001372254, 
0.00000292002192321, 0.00000292002479346, 0.00000292002612198 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta2) { + + NDArray grad0('c', { 1, 5 }, { 0.22060230374336243, 0.10593396425247192, 0.9027279019355774, 0.831809401512146, 0.2733047902584076 }, DataType::FLOAT32); + NDArray initMsg('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMsdx('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto rho = NDArrayFactory::create(0.95f); + auto epsilon = NDArrayFactory::create(1.0e-6); + + sd::ops::ada_delta_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initMsg, &initMsdx, &rho, &epsilon }, { &grad0, &initMsg, &initMsdx }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.0044712172817412, 0.00446815612502933, 0.00447208107763182, 0.004472071321461, 0.00447153735969189 }, DataType::FLOAT32); + NDArray stateMsg0('c', { 1, 5 }, { 0.00243326882084394, 0.0005611002391122, 0.04074588324665051, 0.03459534402219976, 0.00373477541890961 }, DataType::FLOAT32); + NDArray stateMsdx0('c', { 1, 5 }, { 0.00000099958919903, 0.00000099822095788, 0.00000099997545825, 0.00000099997109521, 0.00000099973231796 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg0)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx0)); + + NDArray grad1('c', { 1, 5 }, { 0.6351608633995056, 0.21878601610660553, 0.6470938920974731, 0.3742971122264862, 0.9453978538513184 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initMsg, &initMsdx, &rho, &epsilon }, { &grad1, &initMsg, &initMsdx }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00598985959779411, 0.00571609509028959, 
0.00374704195122062, 0.00265092283150538, 0.00608704322078556 }, DataType::FLOAT32); + NDArray stateMsg1('c', { 1, 5 }, { 0.02248307149952203, 0.00292641126934659, 0.05964511434381081, 0.03987049323214412, 0.0482368917512981 }, DataType::FLOAT32); + NDArray stateMsdx1('c', { 1, 5 }, { 0.00000274353063914, 0.00000258199706405, 0.00000165199285454, 0.00000130134213338, 0.00000280235046064 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg1)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx1)); + + NDArray grad2('c', { 1, 5 }, { 0.8484492301940918, 0.9634076952934265, 0.6676893830299377, 0.4450211524963379, 0.32364124059677124 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initMsg, &initMsdx, &rho, &epsilon }, { &grad2, &initMsg, &initMsdx }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00685468722145889, 0.00822128238053265, 0.00386965914609878, 0.00308849888680941, 0.00279277397245112 }, DataType::FLOAT32); + NDArray stateMsg2('c', { 1, 5 }, { 0.05735222273539331, 0.04918781007340889, 0.07895331423716523, 0.04777915987899536, 0.05106222979448406 }, DataType::FLOAT32); + NDArray stateMsdx2('c', { 1, 5 }, { 0.00000495569095238, 0.00000583237140987, 0.00000231810630717, 0.0000017132162954, 0.00000305221226067 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initMsg.equalsTo(stateMsg2)); + ASSERT_TRUE(initMsdx.equalsTo(stateMsdx2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAdaDelta3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); // Msg + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); // Msdx + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initMsg('f', { 1, 5 }, DataType::FLOAT32); + NDArray initMsdx('f', { 
1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initMsg.assign(initVC); + initMsdx.assign(initMC); + + sd::ops::ada_delta_updater op; + auto results = op.evaluate({ &grad, &initMsg, &initMsdx }, { 0.95f, 1.0e-6 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.00447209123431084, 0.00447212477470162, 0.00447213098596791, 0.00447213315991723, 0.00447213416614627 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.05000000000000004, 0.20000000000000018, 0.4500000000000004, 0.8000000000000007, 1.250000000000001 }, DataType::FLOAT32); + NDArray stateMsg('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.0000009999800004, 0.00000099999500002, 0.00000099999777778, 0.00000099999875, 0.0000009999992 }, DataType::FLOAT32); + NDArray stateMsdx('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateMsg.assign(stateV0C); + stateMsdx.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, results.at(1), results.at(2) }, { 0.95, 1.0e-6 }, { }); + + NDArray update1C('c', { 1, 5 }, { 0.0045290622655332, 0.00452909666868751, 0.00452910303972733, 0.00452910526959756, 0.00452910630171004 }, DataType::FLOAT32); + + NDArray stateV1C('c', { 1, 5 }, { 0.09750000000000009, 0.39000000000000035, 0.8775000000000008, 1.5600000000000014, 2.4375000000000018 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.00000197560125063, 0.00000197563108174, 0.00000197563660612, 0.00000197563853966, 0.00000197563943461 }, DataType::FLOAT32); + + update.assign(update1C); + stateMsg.assign(stateV1C); + stateMsdx.assign(stateM1C); + + 
ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateMsg, &stateMsdx }, { 0.95f, 1.0e-6 }, { }); + + NDArray update2C('c', { 1, 5 }, { 0.00456759948242601, 0.00456763438748812, 0.00456764085147516, 0.00456764311387702, 0.004567644161047 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.1426250000000001, 0.5705000000000005, 1.2836250000000011, 2.282000000000002, 3.5656250000000025 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.0000029199694397, 0.00000292001372254, 0.00000292002192321, 0.00000292002479346, 0.00000292002612198 }, DataType::FLOAT32); + + update.assign(update2C); + stateMsg.assign(stateV2C); + stateMsdx.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateMsg.isSameShape(results.at(1))); + ASSERT_TRUE(stateMsg.equalsTo(results.at(1))); + ASSERT_TRUE(stateMsdx.isSameShape(results.at(2))); + ASSERT_TRUE(stateMsdx.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNadam1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::nadam_updater op; + + Nd4jStatus status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, 
{ 0.06008325654320519, 0.06008326604320069, 0.06008326920986652, 0.06008327079319956, 0.0600832717431994 }, DataType::FLOAT32); + NDArray stateV('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.499999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.06061258367739481, 0.06061259045578174, 0.06061259271524436, 0.06061259384497576, 0.06061259452281461 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initV, &initM }, { &update, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.06281865774973168, 0.06281866348713228, 0.06281866539959938, 0.06281866635583296, 0.06281866692957314 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + 
ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNadam2) { + + NDArray grad0('c', { 1, 5 }, { 0.8047558665275574, 0.9653639197349548, 0.31240877509117126, 0.9530212879180908, 0.01295729912817478 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::nadam_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initV, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.06008325193356386, 0.0600832558615088, 0.06008321472550684, 0.06008325560661022, 0.0600818092240132 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { 0.00064763200471052, 0.00093192749752604, 0.00009759924275397, 0.00090824957522506, 0.0000001678916007 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.08047558665275573, 0.09653639197349546, 0.03124087750911712, 0.09530212879180906, 0.00129572991281748 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.9839006662368774, 0.8964805603027344, 0.3631269931793213, 0.00931886397302151, 0.6320028901100159 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initV, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.06273730114378717, 0.0596708938019245, 0.06226533928512862, 0.02621380498466489, 0.06059567064824535 }, 
DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.00161504489372718, 0.00173467296502922, 0.00022936285668667, 0.00090742816687558, 0.0003995953768165 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.17081809461116787, 0.17653080880641933, 0.06442948907613753, 0.08670380230993031, 0.06436644593253729 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.7712154984474182, 0.1282273381948471, 0.7019220590591431, 0.8883536458015442, 0.33057701587677 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initV, &initM, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initV, &initM }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.06062658222261493, 0.04001212712739213, 0.06906390273197544, 0.05804376499107734, 0.05097529565845974 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.00220820319387896, 0.00174938054232472, 0.00072182807082381, 0.0016956929387176, 0.00050847694486568 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2308578349947929, 0.1717004617452621, 0.12817874607443808, 0.16686878665909166, 0.09098750292696056 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterNadam3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + + 
sd::ops::nadam_updater op; + auto results = op.evaluate({ &grad, &initV, &initM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.06008325654320519, 0.06008326604320069, 0.06008326920986652, 0.06008327079319956, 0.0600832717431994 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.499999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.06061258367739481, 0.06061259045578174, 0.06061259271524436, 0.06061259384497576, 0.06061259452281461 }, DataType::FLOAT32); + NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + 
ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + + results = op.evaluate({ &grad, &stateV, &stateM }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray update2C('c', { 1, 5 }, { 0.06281865774973168, 0.06281866348713228, 0.06281866539959938, 0.06281866635583296, 0.06281866692957314 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad1) { + + NDArray grad('c', { 1, 5 }, { 1,2,3,4,5 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initH('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray update('c', { 1, 5 }, DataType::FLOAT32); + + sd::ops::ams_grad_updater op; + + Nd4jStatus status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 
0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateH0('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initH.equalsTo(stateH0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateH1('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initH.equalsTo(stateH1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + status = op.execute({ &grad, &initV, &initM, &initH }, { &update, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 
0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateH2('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + + ASSERT_TRUE(update.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initH.equalsTo(stateH2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad2) { + + NDArray grad0('c', { 1, 5 }, { 0.5730348229408264, 0.04330538213253021, 0.249028742313385, 0.6514443755149841, 0.7017051577568054 }, DataType::FLOAT32); + NDArray initH('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initV('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initM('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + auto lr = NDArrayFactory::create(0.001f); + auto beta1 = NDArrayFactory::create(0.9f); + auto beta2 = NDArrayFactory::create(0.999f); + auto epsilon = NDArrayFactory::create(1.0e-8); + + sd::ops::ams_grad_updater op; + + Nd4jStatus status = op.execute({ &grad0, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad0, &initV, &initM, &initH }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp0('c', { 1, 5 }, { 0.00099999944815292, 0.00099999269777932, 0.00099999873015716, 0.00099999951457465, 0.00099999954934402 }, DataType::FLOAT32); + NDArray stateV0('c', { 1, 5 }, { 0.00032836890830282, 0.00000187535612164, 0.00006201531449819, 0.00042437977439011, 0.0004923901284225 }, DataType::FLOAT32); + NDArray stateH0('c', { 1, 5 }, { 0.00032836890830282, 
0.00000187535612164, 0.00006201531449819, 0.00042437977439011, 0.00049239012842255 }, DataType::FLOAT32); + NDArray stateM0('c', { 1, 5 }, { 0.05730348229408263, 0.00433053821325302, 0.0249028742313385, 0.0651444375514984, 0.07017051577568052 }, DataType::FLOAT32); + + ASSERT_TRUE(grad0.equalsTo(updateExp0)); + ASSERT_TRUE(initV.equalsTo(stateV0)); + ASSERT_TRUE(initH.equalsTo(stateH0)); + ASSERT_TRUE(initM.equalsTo(stateM0)); + + NDArray grad1('c', { 1, 5 }, { 0.6404328346252441, 0.9432603120803833, 0.45608729124069214, 0.9097326993942261, 0.748093843460083 }, DataType::FLOAT32); + + status = op.execute({ &grad1, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad1, &initV, &initM, &initH }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp1('c', { 1, 5 }, { 0.00134565543815267, 0.00104022434054697, 0.00130914539820157, 0.00133725290576052, 0.0013453914974122 }, DataType::FLOAT32); + NDArray stateV1('c', { 1, 5 }, { 0.00073819475506065, 0.00089161349711151, 0.00026996891641496, 0.00125156897896282, 0.00105154213691696 }, DataType::FLOAT32); + NDArray stateH1('c', { 1, 5 }, { 0.00073819475506065, 0.00089161349711151, 0.00026996891641496, 0.00125156897896282, 0.00105154213691696 }, DataType::FLOAT32); + NDArray stateM1('c', { 1, 5 }, { 0.11561641752719877, 0.09822351559996603, 0.06802131593227385, 0.14960326373577115, 0.13796284854412078 }, DataType::FLOAT32); + + ASSERT_TRUE(grad1.equalsTo(updateExp1)); + ASSERT_TRUE(initV.equalsTo(stateV1)); + ASSERT_TRUE(initH.equalsTo(stateH1)); + ASSERT_TRUE(initM.equalsTo(stateM1)); + + NDArray grad2('c', { 1, 5 }, { 0.46250319480895996, 0.09698919206857681, 0.21754667162895203, 0.46824514865875244, 0.6005083918571472 }, DataType::FLOAT32); + + status = op.execute({ &grad2, &initV, &initM, &initH, &lr, &beta1, &beta2, &epsilon }, { &grad2, &initV, &initM, &initH }, { }, { }); + ASSERT_EQ(ND4J_STATUS_OK, status); + + NDArray updateExp2('c', { 1, 5 }, { 0.00154098993679222, 0.00103399135000281, 
0.00147364850040774, 0.00149693641196572, 0.00155078467854623 }, DataType::FLOAT32); + NDArray stateV2('c', { 1, 5 }, { 0.00095136576551408, 0.00090012878699251, 0.00031702550183538, 0.00146957092922632, 0.0014111009234709 }, DataType::FLOAT32); + NDArray stateH2('c', { 1, 5 }, { 0.00095136576551408, 0.00090012878699251, 0.00031702550183538, 0.00146957092922632, 0.0014111009234709 }, DataType::FLOAT32); + NDArray stateM2('c', { 1, 5 }, { 0.1503050952553749, 0.09810008324682712, 0.08297385150194167, 0.1814674522280693, 0.1842174028754234 }, DataType::FLOAT32); + + ASSERT_TRUE(grad2.equalsTo(updateExp2)); + ASSERT_TRUE(initV.equalsTo(stateV2)); + ASSERT_TRUE(initH.equalsTo(stateH2)); + ASSERT_TRUE(initM.equalsTo(stateM2)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests18, TestUpdaterAmsGrad3) { + + NDArray gradC('c', { 1, 5 }, { 1, 2, 3, 4, 5 }, DataType::FLOAT32); + NDArray initVC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initMC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + NDArray initHC('c', { 1, 5 }, { 0.0, 0.0, 0.0, 0.0, 0.0 }, DataType::FLOAT32); + + NDArray grad('f', { 1, 5 }, DataType::FLOAT32); + NDArray initV('f', { 1, 5 }, DataType::FLOAT32); + NDArray initM('f', { 1, 5 }, DataType::FLOAT32); + NDArray initH('f', { 1, 5 }, DataType::FLOAT32); + + grad.assign(gradC); + initV.assign(initVC); + initM.assign(initMC); + initH.assign(initHC); + + sd::ops::ams_grad_updater op; + auto results = op.evaluate({ &grad, &initV, &initM, &initH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + NDArray updateC('c', { 1, 5 }, { 0.00099999968377233, 0.00099999984188614, 0.00099999989459076, 0.00099999992094306, 0.00099999993675445 }, DataType::FLOAT32); + NDArray update('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateV0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateV('f', { 1, 5 }, 
DataType::FLOAT32); + + NDArray stateM0C('c', { 1, 5 }, { 0.09999999999999998, 0.19999999999999996, 0.29999999999999993, 0.3999999999999999, 0.4999999999999999 }, DataType::FLOAT32); + NDArray stateM('f', { 1, 5 }, DataType::FLOAT32); + + NDArray stateH0C('c', { 1, 5 }, { 0.001, 0.004, 0.00900000000000001, 0.01600000000000001, 0.02500000000000002 }, DataType::FLOAT32); + NDArray stateH('f', { 1, 5 }, DataType::FLOAT32); + + update.assign(updateC); + stateV.assign(stateV0C); + stateM.assign(stateM0C); + stateH.assign(stateH0C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); + + results = op.evaluate({ &grad, &stateV, &stateM, &stateH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + NDArray update1C('c', { 1, 5 }, { 0.00134383858541481, 0.00134383873569809, 0.00134383878579252, 0.00134383881083974, 0.00134383882586807 }, DataType::FLOAT32); + NDArray stateV1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + NDArray stateM1C('c', { 1, 5 }, { 0.18999999999999995, 0.3799999999999999, 0.5699999999999998, 0.7599999999999998, 0.9499999999999997 }, DataType::FLOAT32); + NDArray stateH1C('c', { 1, 5 }, { 0.001999, 0.00799600000000001, 0.01799100000000001, 0.03198400000000003, 0.04997500000000005 }, DataType::FLOAT32); + + + update.assign(update1C); + stateV.assign(stateV1C); + stateM.assign(stateM1C); + stateH.assign(stateH1C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + 
ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); + + results = op.evaluate({ &grad, &stateV, &stateM, &stateH }, { 0.001f, 0.9f, 0.999f, 1.0e-8 }, { }); + + + NDArray update2C('c', { 1, 5 }, { 0.00156540157923389, 0.00156540172220632, 0.0015654017698638, 0.00156540179369254, 0.00156540180798979 }, DataType::FLOAT32); + NDArray stateV2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + NDArray stateM2C('c', { 1, 5 }, { 0.2709999999999999, 0.5419999999999998, 0.8129999999999998, 1.0839999999999996, 1.3549999999999995 }, DataType::FLOAT32); + NDArray stateH2C('c', { 1, 5 }, { 0.002997001, 0.01198800400000001, 0.02697300900000002, 0.04795201600000004, 0.07492502500000006 }, DataType::FLOAT32); + + + update.assign(update2C); + stateV.assign(stateV2C); + stateM.assign(stateM2C); + stateH.assign(stateH2C); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(update.isSameShape(results.at(0))); + ASSERT_TRUE(update.equalsTo(results.at(0))); + ASSERT_TRUE(stateV.isSameShape(results.at(1))); + ASSERT_TRUE(stateV.equalsTo(results.at(1))); + ASSERT_TRUE(stateM.isSameShape(results.at(2))); + ASSERT_TRUE(stateM.equalsTo(results.at(2))); + ASSERT_TRUE(stateH.isSameShape(results.at(3))); + ASSERT_TRUE(stateH.equalsTo(results.at(3))); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests19.cpp new file mode 100644 index 000000000..0f2c73aa8 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -0,0 +1,427 @@ +/* + * 
****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include + + +using namespace sd; + + +class DeclarableOpsTests19 : public testing::Test { +public: + + DeclarableOpsTests19() { + printf("\n"); + fflush(stdout); + } +}; + + +TEST_F(DeclarableOpsTests19, test_argmax_maxint_vector_1) { + auto x = NDArrayFactory::create('c', {3}, {0.1f, 0.5f, 0.7f}); + auto z = NDArrayFactory::create(0); + auto e = NDArrayFactory::create(2); + + sd::ops::argmax op; + auto status = op.execute({&x}, {&z}, {DataTypeUtils::max()}); + ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(e, z); +} + + +TEST_F(DeclarableOpsTests19, test_threshold_encode_1) { + auto x = NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); + auto exp_encoded = NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); + auto exp_gradients = NDArrayFactory::create('c', {3}, {1.0, 2.0, -3.0}); + + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {0.5}); + + auto gradients = result.at(0); + auto encoded = result.at(1); + + //encoded->printIndexedBuffer("ENC"); + 
+ ASSERT_EQ(exp_encoded, *encoded); + ASSERT_EQ(exp_gradients, x); + + // FIXME: we need to add a way to declare individual inplace outputs + //ASSERT_EQ(exp_gradients, *gradients); +} + +TEST_F(DeclarableOpsTests19, test_threshold_encode_2) { + for (int length = 5; length < 35; length++) { + auto x = NDArrayFactory::create('c', {10000}); + auto exp_gradients = NDArrayFactory::create('c', {10000}); + + for (int e = 0; e < length; e++) { + x.p(e, 2e-3); + exp_gradients.p(e, 1e-3); + } + + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {1e-3}); + + auto encoded = result.at(1); + + ASSERT_EQ(length + 4, encoded->lengthOf()); + ASSERT_EQ(exp_gradients, x); + } +} + +TEST_F(DeclarableOpsTests19, test_threshold_encode_boundary_1) { + auto x = NDArrayFactory::create('c', {6}); + x = 1.0f; + + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {1.0}, {3}); + + auto gradients = result.at(0); + auto encoded = result.at(1); + + ASSERT_EQ(7, encoded->lengthOf()); + ASSERT_EQ(3, x.sumNumber().e(0)); +} + +TEST_F(DeclarableOpsTests19, test_threshold_encode_boundary_2) { + auto x = NDArrayFactory::create('c', {1000}); + x = 1.0f; + + sd::ops::encode_threshold op; + auto result = op.evaluate({&x}, {1.0}, {100}); + + auto gradients = result.at(0); + auto encoded = result.at(1); + + ASSERT_EQ(104, encoded->lengthOf()); + + ASSERT_EQ(900, x.sumNumber().e(0)); +} + +TEST_F(DeclarableOpsTests19, test_threshold_decode_1) { + auto x = NDArrayFactory::create('c', {3}, {1.0, 2.0, -3.0}); + auto y = NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); + auto exp_gradients = NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); + + sd::ops::decode_threshold op; + auto status = op.execute({&x, &y}, {&x}); + ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(exp_gradients, x); +} + +TEST_F(DeclarableOpsTests19, test_bitmap_encode_1) { + auto initial = NDArrayFactory::create('c', {6}, {0.0f, 0.0f, 1e-3f, -1e-3f, 0.0f, 0.0f}); + auto exp_0 = initial.like(); + 
auto exp_1 = initial.dup(); + auto exp_c = NDArrayFactory::create(2L); + + sd::ops::encode_bitmap enc; + auto enc_result = enc.evaluate({&initial}, {1e-3f}); + ASSERT_EQ(Status::OK(), enc_result.status()); + + //initial.printIndexedBuffer("initial"); + ASSERT_EQ(exp_0, initial); + + auto encoded = enc_result.at(1); + auto counter = enc_result.at(2); + + //encoded->printIndexedBuffer("encoded"); + + ASSERT_EQ(exp_c, *counter); + + sd::ops::decode_bitmap dec; + auto status = dec.execute({&initial, encoded}, {&initial}); + ASSERT_EQ(Status::OK(), status); + + + //initial.printIndexedBuffer(); + + ASSERT_EQ(exp_1, initial); +} + +TEST_F(DeclarableOpsTests19, test_bitmap_encode_decode) { + auto initial = NDArrayFactory::create('c', {256000}); + initial = 1.0f; + auto exp = initial.dup(); + auto neg = initial.like(); + neg = 0.5f; + + sd::ops::encode_bitmap enc; + auto enc_result = enc.evaluate({&initial}, {0.5f}); + auto encoded = enc_result.at(1); + + // checking equality of all encoded bits + for (int e = 5; e < encoded->lengthOf() - 1; e++) { + if (encoded->e(e) != encoded->e(e - 1)) + nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e(e)); + } + + ASSERT_NE(exp, initial); + ASSERT_EQ(neg, initial); + + sd::ops::decode_bitmap dec; + auto status = dec.execute({&initial, encoded}, {&initial}); + ASSERT_EQ(Status::OK(), status); + + // checking equality of all dedoded bits + for (int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 1.0f) + nd4j_printf("initial[%i] = %f\n", e, f); + } + + + ASSERT_EQ(exp, initial); +} + +TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) { + auto initial = NDArrayFactory::create('c', {256000}); + initial = 1.0f; + auto exp = initial.dup(); + auto neg = initial.like(); + neg = 0.5f; + + sd::ops::encode_threshold enc; + auto enc_result = enc.evaluate({&initial}, {0.5f}); + auto encoded = enc_result.at(1); + + ASSERT_EQ(256000 + 4, encoded->lengthOf()); + ASSERT_NE(exp, initial); + + for 
(int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 0.5f) { + nd4j_printf("initial[%i] = %f\n", e, f); + throw std::runtime_error(""); + } + } + ASSERT_EQ(neg, initial); + + // checking equality of all encoded bits + //for (int e = 5; e < encoded->lengthOf() - 1; e++) { + //if (encoded->e(e) != encoded->e(e - 1) + 1) + //nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e(e)); + //} + + sd::ops::decode_threshold dec; + auto status = dec.execute({&initial, encoded}, {&initial}); + ASSERT_EQ(Status::OK(), status); + + // checking equality of all dedoded bits + for (int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 1.0f) + nd4j_printf("initial[%i] = %f\n", e, f); + } + + ASSERT_EQ(exp, initial); +} + +#ifdef _RELEASE +TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { + // [2,1,135079944,1,1,8192,1,99] + constexpr int sizeX= 10*1000*1000; + auto initial = NDArrayFactory::create('c', {1, sizeX}); + initial = 1.0f; + auto exp = initial.dup(); + auto neg = initial.like(); + neg = 0.5f; + + sd::ops::encode_threshold enc; + auto enc_result = enc.evaluate({&initial}, {0.5f}); + auto encoded = enc_result.at(1); + + ASSERT_EQ(sizeX + 4, encoded->lengthOf()); + ASSERT_NE(exp, initial); +/* + for (int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 0.5f) { + nd4j_printf("initial[%i] = %f\n", e, f); + throw std::runtime_error(""); + } + } + */ + ASSERT_EQ(neg, initial); + + // checking equality of all encoded bits + //for (int e = 5; e < encoded->lengthOf() - 1; e++) { + //if (encoded->e(e) != encoded->e(e - 1) + 1) + //nd4j_printf("Non equal encoded values at E[%i]: %i;\n", e, encoded->e(e)); + //} + + sd::ops::decode_threshold dec; + auto status = dec.execute({&initial, encoded}, {&initial}); + ASSERT_EQ(Status::OK(), status); + + // checking equality of all dedoded bits + /* + for (int e = 0; e < initial.lengthOf(); e++) { + auto f = initial.e(e); + if (f != 1.0f) + 
nd4j_printf("initial[%i] = %f\n", e, f); + } + */ + + ASSERT_EQ(exp, initial); +} +#endif + + + +TEST_F(DeclarableOpsTests19, test_matmul_ccc) { + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('c', {10, 10}); + auto z = NDArrayFactory::create('c', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests19, test_matmul_fcf) { + auto x = NDArrayFactory::create('f', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests19, test_matmul_cff) { + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('f', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + + +TEST_F(DeclarableOpsTests19, test_matmul_ccf) { + auto x = NDArrayFactory::create('c', {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests19, 
test_matmul_fff) { + auto x = NDArrayFactory::create('f', {10, 10}); + auto y = NDArrayFactory::create('f', {10, 10}); + auto e = NDArrayFactory::create('f', {10, 10}); + auto z = NDArrayFactory::create('f', {10, 10}); + + z.assign(100.f); + e.assign(110.f); + x.assign(1.0f); + y.assign(1.0f); + + sd::ops::matmul op; + auto status = op.execute({&x, &y}, {&z}, {1.0, 1.0}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) { + /* + DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp") + .addInputs( + Nd4j.create(DataType.FLOAT, 2,2,12), + Nd4j.create(DataType.FLOAT, 3,2,3), + Nd4j.create(DataType.FLOAT, 2,3,6) + ) + .addOutputs( + Nd4j.create(DataType.FLOAT, 2,2,12), + Nd4j.create(DataType.FLOAT, 3,2,3)) + .addIntegerArguments(3,2,0,1,2,0) + .build(); + + Nd4j.exec(op); + */ + + auto t = NDArrayFactory::create('c', {2, 2, 12}); + auto u = NDArrayFactory::create('c', {3, 2, 3}); + auto v = NDArrayFactory::create('c', {2, 3, 6}); + + sd::ops::conv1d_bp op; + auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0}); + ASSERT_EQ(Status::OK(), result.status()); + +} + +TEST_F(DeclarableOpsTests19, test_squeeze_1) { + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto e = NDArrayFactory::create('c', {3, 4}); + int axis = 2; + + sd::ops::squeeze op; + auto status = op.execute({&x}, {&e}, {axis}); + ASSERT_EQ(Status::OK(), status); +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests2.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests2.cpp new file mode 100644 index 000000000..a7fefe18f --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests2.cpp @@ -0,0 +1,4487 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which 
is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +#include "testlayers.h" +#include +#include +#include +#include +#include +using namespace sd; +using namespace sd::graph; + +class DeclarableOpsTests2 : public testing::Test { +public: + + DeclarableOpsTests2() { + printf("\n"); + } +}; + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_1) { + + NDArray input('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); + NDArray indices('c', {1,6}, {0,1, 2,2, 1,2}, sd::DataType::INT32); + NDArray expected('c', {2,1,6,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24}, sd::DataType::FLOAT32); + + sd::ops::gather op; + + auto result = op.evaluate({&input, &indices}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto* output = result.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +TEST_F(DeclarableOpsTests2, gather_2) { + + NDArray input('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); + //auto indices ('c', {1,6}, {0,1, 2,2, 1,2}); + NDArray expected('c', {2,6,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 
17,18,19,20, 21,22,23,24}); + + sd::ops::gather op; + + auto result = op.evaluate({&input}, {}, {1, 0,1, 2,2, 1,2}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto* output = result.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_3) { + + NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); + NDArray indices ('c', {1,1}, std::vector{2}, sd::DataType::INT32); + NDArray expected('c', {2,1,1,4}, {9,10,11,12,21,22,23,24}); + + sd::ops::gather op; + + auto result = op.evaluate({&input, &indices}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto* output = result.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +TEST_F(DeclarableOpsTests2, gather_4) { + + NDArray input('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); + //auto indices ('c', {1,1}, {2}); + NDArray expected('c', {2,4}, {9,10,11,12,21,22,23,24}); + + sd::ops::gather op; + + auto result = op.evaluate({&input}, {}, {1, 2}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto* output = result.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_5) { + + NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); + NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT32); + NDArray expected('c', {2,2,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 21,22,23,24,17,18,19,20,21,22,23,24}); + + sd::ops::gather op; + + auto result = op.evaluate({&input, &indices}, {}, {1}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, 
result.status()); + + auto* output = result.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_6) { + + NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}); + NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT32); + NDArray expected('c', {2,3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36, 25,26,27,28,29,30,31,32,33,34,35,36, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}); + + sd::ops::gather op; + + auto result = op.evaluate({&input, &indices}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto* output = result.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_7) { + + NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); + NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT64); + NDArray expected('c', {2,3,2,3}, {1, 2, 3, 3, 2, 3, 5, 6, 7, 7, 6, 7, 9,10,11,11,10,11, 13,14,15,15,14,15, 17,18,19,19,18,19, 21,22,23,23,22,23}); + + sd::ops::gather op; + + auto result = op.evaluate({&input, &indices}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto* output = result.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_8) { + + NDArray input('c', {3,5}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, sd::DataType::FLOAT32); + NDArray indices('c', {1}, std::vector{2}, sd::DataType::INT32); + NDArray 
expected('c', {1,5}, {11, 12, 13, 14, 15.}, sd::DataType::FLOAT32); + + sd::ops::gather op; + + auto result = op.evaluate({&input, &indices}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto* output = result.at(0); + // output->printShapeInfo(); + // output->printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_9) { + NDArray x('c', {2, 4, 3, 2}, sd::DataType::FLOAT32); + NDArray indices('c', {2}, std::vector{1, 0}, sd::DataType::INT32); + + sd::ops::gather op; + auto result = op.evaluate({&x, &indices}, {}, {-2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_10) { + NDArray x('c', {2, 2}, {1, 2, 3, 4}); + NDArray e('c', {2, 2}, {3, 4, 1, 2}); + + sd::ops::gather op; + auto result = op.evaluate({&x}, {}, {0, 1, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_11) { + + NDArray x('c', {2, 2}, {1, 2, 3, 4}); + NDArray indices('c', {2}, std::vector{1, 0}, sd::DataType::INT64); + NDArray e('c', {2, 2}, {3, 4, 1, 2}); + + sd::ops::gather op; + auto result = op.evaluate({&x, &indices}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_12) { + + NDArray input('c', {4}, {2.f, 3.f, 4.f, 5.f}); + NDArray indices('c', {2}, {0, 2}, sd::DataType::INT32); + NDArray exp('c', {2}, {2.f, 4.f}); + + sd::ops::gather op; + auto result = op.evaluate({&input, 
&indices}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_13) { + + NDArray input ('c', {2,3,4,5}, sd::DataType::DOUBLE); + NDArray indices ('c', {2,3,4}, {0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3,0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3}, sd::DataType::INT32); + NDArray expected('c', {2,3, 2,3,4, 5}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, + 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 
65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, + 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, + 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119, 100,101,102,103,104, 105,106,107,108,109, 110,111,112,113,114, 115,116,117,118,119}); + + input.linspace(0); + + sd::ops::gather op; + + auto result = op.evaluate({&input, &indices}, {}, {2}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto* output = result.at(0); + + ASSERT_TRUE(expected.isSameShapeStrict(*output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_14) { + + NDArray input ('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); + NDArray indices ('c', {2,3}, {0, 10, 2, 20, 1,2}, sd::DataType::INT32); + NDArray output('c', {2,2,3,4}); + + 
sd::ops::gather op; + + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {1}, {true})); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, gather_15) { + + NDArray input ('c', {2,3,4,5}, sd::DataType::DOUBLE); + NDArray indices ('c', {2,3,4}, {0, 10, 2, 3, 0, 1, 20, 3, 0, 1, 2, 3,0, 1, 2, 3, 0, 1, 2, 30, 0, 1, 2, 3}, sd::DataType::INT32); + NDArray output('c', {2,3, 2,3,4, 5}); + + sd::ops::gather op; + + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {2}, {true})); +} + +TEST_F(DeclarableOpsTests2, BroadcastGradientArgs_1) { + + NDArray input ('c', {3,3,4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16,17,18,19,20,21,22,23,24, 25,26,27,28,29,30,31,32,33,34,35,36}, sd::DataType::INT32); + NDArray indices ('c', {2,3}, {0, 1, 2, 2, 1,2}, sd::DataType::INT32); + + sd::ops::broadcastgradientargs op; + + auto result = op.evaluate({&input, &indices}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_KERNEL_FAILURE, result.status()); + +} + +TEST_F(DeclarableOpsTests2, NLP_Cbow_Test_1) { + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0095); + exp1.assign(0.019875); + exp2.assign(0.02); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto locked = NDArrayFactory::create('c', {3}); + auto indices = NDArrayFactory::create('c', {2}, {4, 5}); + auto codes = NDArrayFactory::create('c', {2}, {1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto numWords = NDArrayFactory::create('c', {1}, {1}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha 
= NDArrayFactory::create(0.025); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row_s0_0 = syn0({0,1, 0,0}, true); + auto row_s0_1 = syn0({1,2, 0,0}, true); + auto row_s0_2 = syn0({2,3, 0,0}, true); + + auto row_s1_4 = syn1({4,5, 0,0}, true); + auto row_s1_5 = syn1({5,6, 0,0}, true); + auto row_s1_6 = syn1({6,7, 0,0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + + ASSERT_EQ(exp1, row_s1_4); + ASSERT_EQ(exp1, row_s1_5); + ASSERT_EQ(exp2, row_s1_6); + +} + +TEST_F(DeclarableOpsTests2, Test_Squeeze_1) { + auto x = NDArrayFactory::create('c', {2, 1, 3, 1, 1, 1, 4}); + x.linspace(1); + auto exp = x.reshape('c', {2, 3, 4}); + + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(DeclarableOpsTests2, Test_Squeeze_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + auto exp = new NDArray(x.dup()); + + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + + delete exp; +} + +TEST_F(DeclarableOpsTests2, Test_FloorMod_1) { + auto x = NDArrayFactory::create('c', {1, 3}, {2.0f, 6.0f, -3.0f}); + auto y = NDArrayFactory::create('c', {1, 3}, {-3.0f, 2.0f, -2.0f}); + auto exp = NDArrayFactory::create('c', {1, 3}, {-1.f, 0.f, -1.f}); + + sd::ops::floormod op; + + auto result = op.evaluate({&x, &y}, {}, {}); + + auto z = result.at(0); + + 
ASSERT_TRUE(exp.isSameShape(z)); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(DeclarableOpsTests2, Test_FloorDiv_1) { + auto x = NDArrayFactory::create('c', {1, 3}, {3.0f, 6.0f, -3.0f}); + auto y = NDArrayFactory::create('c', {1, 3}, {-2.0f, 2.0f, -2.0f}); + auto exp = NDArrayFactory::create('c', {1, 3}, {-2.f, 3.f, 1.f}); + + sd::ops::floordiv op; + + auto result = op.evaluate({&x, &y}, {}, {}); + + auto z = result.at(0); +// z->printShapeInfo("FloorDiv1 shape"); +// z->printIndexedBuffer("FloorDiv1"); + ASSERT_TRUE(exp.isSameShape(z)); +} + +TEST_F(DeclarableOpsTests2, Test_FloorDiv_2) { + auto x = NDArrayFactory::create('c', {1, 3}, {3.0f, 6.0f, -3.0f}); + auto y = NDArrayFactory::create('c', {1, 3}, {-2.0f, 2.0f, -2.0f}); + auto eps = NDArrayFactory::create('c', {1, 3}, {1.f, 2.f, 3.f}); + + auto exp1 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto exp2 = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + + sd::ops::floordiv_bp op; + + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto z1 = result.at(0); + auto z2 = result.at(1); +// z->printShapeInfo("FloorDiv1 shape"); +// z1->printIndexedBuffer("FloorDiv2_1"); +// z2->printIndexedBuffer("FloorDiv2_2"); + + ASSERT_TRUE(exp1.equalsTo(z1)); + ASSERT_TRUE(exp2.equalsTo(z2)); + +} + +TEST_F(DeclarableOpsTests2, Test_CRelu_1) { + auto x = NDArrayFactory::create('c', {2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); + auto exp = NDArrayFactory::create('c', {2, 4}, {1.0f, 2.0f, 0.f, 0.f, 3.0f, 4.0f, 0.f, 0.f}); + + sd::ops::crelu op; + + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(DeclarableOpsTests2, Test_CRelu_BP_2) { + auto x = NDArrayFactory::create('c', {2, 2}, {1.0f, 2.0f, -3.0f, 4.0f}); + auto eps = NDArrayFactory::create('c', {2, 4}, {1.0f, 2.0f, 4.f, 3.f, 3.0f, 4.0f, 2.f, 1.f}); + auto exp = 
NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, -2.f, 4.f}); + + sd::ops::crelu_bp op; + auto result = op.evaluate({&x, &eps}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(DeclarableOpsTests2, Test_Concat_BP_1) { + auto x = NDArrayFactory::create('c', {2, 2}); + auto y = NDArrayFactory::create('c', {2, 2}); + auto eps = NDArrayFactory::create('c', {2, 4}, {1.0f, 2.0f, 0.f, 1.f, 3.0f, 4.0f, 0.f, 1.f}); + auto expEX = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto expEY = NDArrayFactory::create('c', {2, 2}, {0.f, 1.f, 0.f, 1.f}); + + sd::ops::concat_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {-1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto epsX = result.at(0); + auto epsY = result.at(1); + + ASSERT_TRUE(expEX.isSameShape(epsX)); + ASSERT_TRUE(expEX.equalsTo(epsX)); + + ASSERT_TRUE(expEY.isSameShape(epsY)); + ASSERT_TRUE(expEY.equalsTo(epsY)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_1) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,4,5}); + auto expected = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_2) { + + auto labels = 
NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {1,1,4,5}); + auto expected = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + // result->printIndexedBuffer("ADL test2"); + // expected.printIndexedBuffer("ADL expec"); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_3) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {1,1,1,5}); + auto expected = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_4) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,1,1,5}); + auto expected = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = 
op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_5) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {1,1}); + auto expected = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + expected.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_6) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {1,1}); + auto expected = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.f); + expected.assign(0.f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_7) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', 
{2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 60.f); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_8) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 0.f); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_9) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,1,4,1}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 60.); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_10) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = 
NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {1,1}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 60.f); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_11) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {1,1}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 1.f); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_12) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {1,1}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 0.f); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_13) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = 
NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 1.f); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_14) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5); + weights.p(1, 0.f); + weights.p(2, 0.f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 1.f); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_15) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 2.f); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_16) { + + auto labels = 
NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); + predictions.p(0, 0.f); + predictions.p(1, 0.f); + predictions.p(2, 0.f); + predictions.p(3, 0.f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 2.01667, 1e-5); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_17) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); + predictions.p(0, 0.f); + predictions.p(1, 0.f); + predictions.p(2, 0.f); + predictions.p(3, 0.f); + labels.p(0, 0.f); + labels.p(1, 0.f); + labels.p(2, 0.f); + labels.p(3, 0.f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 1.93333, 1e-5); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_18) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,1,1,5}); + + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5f); + predictions.p(0, 0.f); + predictions.p(1, 0.f); + predictions.p(2, 0.f); + predictions.p(3, 0.); + labels.p(0, 0.f); + labels.p(1, 
0.f); + labels.p(2, 0.f); + labels.p(3, 0.f); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 1.93333f, 1e-5); + + + +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_19) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {1,1}); + + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 1.); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_20) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 1.); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_21) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,1,1}); + + 
labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 1.f); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_22) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {1,1}); + + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 0.); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, absolute_difference_loss_test_23) { + + auto labels = NDArrayFactory::create('c', {2,3,4,5}); + auto predictions = NDArrayFactory::create('c', {2,3,4,5}); + auto weights = NDArrayFactory::create('c', {2,3,4,5}); + + labels.linspace(1); + predictions.linspace(3); + weights.assign(0.5); + predictions.p(0, 0.); + predictions.p(1, 0.); + predictions.p(2, 0.); + predictions.p(3, 0.); + labels.p(0, 0.); + labels.p(1, 0.); + labels.p(2, 0.); + labels.p(3, 0.); + weights.p(40+0, 0.); + weights.p(40+1, 0.); + weights.p(40+2, 0.); + weights.p(40+3, 0.); + + sd::ops::absolute_difference_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.965517, 1e-5); + + + +} + 
+//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, cosine_distance_loss_test1) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,3,4}); + auto expected = NDArrayFactory::create('c', {1,3,4}, {-91.5f, -107.5f, -125.5f, -145.5f, -167.5f, -191.5f, -217.5f, -245.5f, -275.5f, -307.5f, -341.5f, -377.5f}); + + labels.linspace(1); + predictions.linspace(2); + weights.assign(0.5); + + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, cosine_distance_loss_test2) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1,4}); + auto expected = NDArrayFactory::create('c', {2,1,4}, {-3.25f, -4.f, -4.75f, -5.5f, -12.25f, -13.f, -13.75f, -14.5f}); + + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); + + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + + +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, cosine_distance_loss_test3) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,1}); + auto expected = NDArrayFactory::create('c', {2,3,1}, {-2.f, 
-6.f,-10.f,-14.f,-18.f,-22.f}); + + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); + + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, cosine_distance_loss_test4) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + auto expected = NDArrayFactory::create('c', {2,3,1}, {-2.f, -6.f,-10.f,-14.f,-18.f,-22.f}); + + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); + + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0,2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, cosine_distance_loss_test5) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1,4}); + + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); + + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == -71.); + + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, cosine_distance_loss_test6) { + + auto labels = NDArrayFactory::create('c', 
{2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); + + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == -71.f); + + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, cosine_distance_loss_test7) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1,4}); + + labels.linspace(1); + weights.assign(0.5); + predictions.assign(0.5); + + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1,0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == -69.f); + + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, cosine_distance_loss_test8) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,1}); + + labels.linspace(1); + weights.assign(0.5f); + predictions.assign(0.5f); + + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == -24.f); + + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, cosine_distance_loss_test9) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = 
NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + labels.linspace(1); + weights.assign(0.5f); + predictions.assign(0.5f); + + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == -24.); + + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, cosine_distance_loss_test10) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,1}); + + labels.linspace(1); + weights.assign(0.5f); + predictions.assign(0.5f); + weights.p(0, 0.f); + weights.p(1, 0.f); + + sd::ops::cosine_distance_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2,2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == -32.); + + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test1) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. 
,0.,11.5, 0.,12.5}); + + logits.linspace(1); + weights.assign(0.5); + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + // result->printBuffer(); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test2) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. ,0.,11.5, 0.,12.5}); + + logits.linspace(1); + weights.assign(0.5); + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + // result->printBuffer(); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test3) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,3,1}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {1., 0. , 0., 2.5,0., 3.5, 0., 4.5,0., 5.5, 0., 6.5, 0., 7.5, 0., 8.5,0., 9.5,10., 0. 
,0.,11.5, 0.,12.5}); + + logits.linspace(1); + weights.assign(0.5); + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + // result->printBuffer(); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test4) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + logits.linspace(1); + weights.assign(0.5); + + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 83.); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test5) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + logits.linspace(1); + weights.assign(0.5); + + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 83.); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test6) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1,1}); + + 
logits.linspace(1); + weights.assign(0.5); + + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 83.); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test7) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + logits.linspace(1); + weights.assign(0.5); + + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 6.91667, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test8) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + logits.linspace(1); + weights.assign(0.5); + + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 6.91667, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test9) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1,4}); + + logits.linspace(1); + weights.assign(0.5); + + + 
sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 6.91667, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test10) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + logits.linspace(1); + weights.assign(0.5); + + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 3.45833, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test11) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1,4}); + + logits.linspace(1); + weights.assign(0.5); + + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 3.45833, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test12) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + logits.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + 
weights.p(2, 0.); + weights.p(3, 0.); + + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 3.975, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, hinge_loss_test13) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + logits.linspace(1); + weights.assign(0.); + + + sd::ops::hinge_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_TRUE(result->e(0) == 0.); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, huber_loss_test1) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0.0425 ,0.0875 ,0.13250001,0.17749999,0.22250001,0.26750001,0.31250003,0.35749999,0.4025 ,0.44749999,0.49249998,0.53750002, 0.58249998,0.6275 ,0.67250001,0.71749997,0.76249999,0.8075 ,0.85250002,0.89749998,0.9425 ,0.98749995,1.03250015,1.0775001}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests2, huber_loss_test2) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,1}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0.0425 ,0.0875 ,0.13250001,0.17749999,0.22250001,0.26750001,0.31250003,0.35749999,0.4025 ,0.44749999,0.49249998,0.53750002, 0.58249998,0.6275 ,0.67250001,0.71749997,0.76249999,0.8075 ,0.85250002,0.89749998,0.9425 ,0.98749995,1.03250015,1.0775001}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, huber_loss_test3) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0.0425 ,0.0875 ,0.13250001,0.17749999,0.22250001,0.26750001,0.31250003,0.35749999,0.4025 ,0.44749999,0.49249998,0.53750002, 0.58249998,0.6275 ,0.67250001,0.71749997,0.76249999,0.8075 ,0.85250002,0.89749998,0.9425 ,0.98749995,1.03250015,1.0775001}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, huber_loss_test4) { + + auto labels = 
NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 13.44, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, huber_loss_test5) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 13.44, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, huber_loss_test6) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 1.12, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, huber_loss_test7) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = 
NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1,1}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 1.12, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, huber_loss_test8) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 1.3, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, huber_loss_test9) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.56, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, huber_loss_test10) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto 
predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.56, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, huber_loss_test11) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + labels.linspace(0.1, 0.1); + predictions.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::huber_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {0.1}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.65, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test1) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433 , -4.09208679, -5.83734226, -7.8636713 ,-10.23689461,-13.05822182,-16.49509811,-20.85659218,-26.82411766,-36.52717209}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, 
&labels}, {1e-7}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test2) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1,4}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433 , -4.09208679, -5.83734226, -7.8636713 ,-10.23689461,-13.05822182,-16.49509811,-20.85659218,-26.82411766,-36.52717209}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test3) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + NDArray weights(sd::DataType::DOUBLE); + auto expected = NDArrayFactory::create('c', {2,3,4}, {1.60943663, 2.48403668, 3.05256081, 3.40363169, 3.57730675, 3.59525585, 3.46986699, 3.20791793, 2.81228209, 2.28273821, 1.61630058, 0.80721998, -0.15329313, -1.27764463, -2.5828433 , -4.09208679, -5.83734226, -7.8636713 ,-10.23689461,-13.05822182,-16.49509811,-20.85659218,-26.82411766,-36.52717209}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, 
&weights, &labels}, {1e-7}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test4) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -113.886429, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test5) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,3,1}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -113.886429, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test6) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + NDArray weights(sd::DataType::DOUBLE); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {1}); + + 
ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -113.886429, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test7) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -9.490536, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test8) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,3,1}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -9.490536, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test9) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + NDArray weights(sd::DataType::DOUBLE); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = 
results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -9.490536, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test10) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -12.443609, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test11) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -4.745268, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test12) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, 
results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -4.745268, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, log_loss_test13) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + predictions.linspace(0.04, 0.04); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::log_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {1e-7}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -6.221805, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test1) { + auto labels = NDArrayFactory::create('c', {1,3}, {0., 0.5, 1.}); + auto predictions = NDArrayFactory::create('c', {1,3}, {1., 1., 1.}); + auto weights = NDArrayFactory::create('c', {1,1}, {1}); + auto expected = NDArrayFactory::create('c', {1,1}, {1.}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test2) { + auto labels = NDArrayFactory::create('c', {10,4}, {-0.5533444483384939, -0.4045807428083095, -0.38990808632111873, -1.3367815555936828, 2.2110825342567204, -0.3322538938773163, 0.5683588435736076, 1.401524673423209, -0.2216208609234102, -0.23645194877057543, -1.9319189398422172, 0.6106128799796062, 
1.6973842275926025, -2.8306371397325553E-4, -1.1550401544465256, -0.08357706614294765, -0.27784822018757077, 0.8290894318337857, 1.6484476009013025, -0.7752524785358668, -0.9700596207063842, 3.0809371469543207, -0.23684959888998405, 0.22403535560739518, 0.6146150452128438, -1.1250088686147994, -0.5915314787415693, -0.0944090155356556, 0.7995514825959854, -1.2290496239142903, -1.8329592004926936, -0.1694821152623061, -1.7614978090471403, 0.07929168376086736, 0.4086255139492943, 2.045562727396195, -0.48701853719962834, 0.10304152395720723, -0.8993147347502636, -0.49078404206110715}); + auto predictions = NDArrayFactory::create('c', {10,4}, {-0.5982871220907984, 1.2010665656903237, 0.30243355682445544, -0.2070857400459659, 0.6962389393180044, -0.5878034128580758, 0.8325626284025988, -0.3555823702782838, -0.7099759151434476, 1.7971905051128672, -1.1018498592680859, 0.008705918349147959, -1.713038986676157, 0.5029671900704719, 0.7491261275031563, -0.34800067781360444, -1.3529065441284513, -0.6075230577852321, -0.6153583973120907, 1.6014780660677996, 0.6444219215516616, 0.7925830851904783, -0.5006063079380708, 1.7812300901376552, 0.4736193941708224, 1.411502849640833, 0.9555142545037492, -0.03936687661890644, 1.31661624967917, 0.7344531724786305, 0.8388550872918745, 0.7010030219905558, -0.5442944240155373, 0.4437344837841118, -1.7502823958671712, -1.9271369730241665, 0.9256612923554498, 1.9065401403827893, 0.42450175148842717, -0.11783183865542822}); + auto weights = NDArrayFactory::create('c', {1,1}, {1}); + auto expected = NDArrayFactory::create('c', {10,1}, {1.9665822560405073, 3.806679563402927, 6.185624212589066, 20.237895345263905, 16.739700814450472, 13.655430201400929, 6.473256392322658, 3.9337379694106325, 22.509455553531062, 1.4741234749089487}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + 
ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test3) { + auto labels = NDArrayFactory::create('c', {10,4}, {0.9165069946629816, 0.166426191704143, 0.13873357227527264, -0.5986162145785378, 0.4763504550662989, 1.2259816058633732, -0.4653205175596491, -1.7447031523970766, 1.349525448316014, 2.433089865629357, -2.54858150221601, -0.6060282162911894, 0.2625377104613349, -0.5007107584102752, 0.9576065700956302, -0.35787770401703584, -0.2608532564720665, 0.65688909921908, -0.1705876431948587, 1.2052884124800949, -0.976783296084278, 1.1163504624016534, -0.10545986164581109, -1.0632271027867568, 0.26460250034147065, -0.2299030354616135, -0.418989869909565, 0.7954060747536896, 0.37934127200736545, 0.8550487997440007, 0.2984909806904042, 0.1329065864221682, 1.478600294413247, 0.05421279873635542, -1.0552978360622536, -0.743808639782604, -1.3371851696151362, 2.7752972493355963, -1.6107187893743549, 1.5030902829432997}); + auto predictions = NDArrayFactory::create('c', {10,4}, {-3.398114657004427, 0.40587455906092945, 1.587706448479039, 0.27394335709083156, 1.0463122023764637, -0.6552570653663903, -0.26929204111727345, -2.710461824817806, 0.9141296064806023, -0.7632270851454939, -0.4077235519855459, 0.5555107559107472, -0.6776140976423888, 1.2422270521180823, 0.2372445100636733, 0.08522757123963924, -2.708523129389936, 0.09738215252575103, -0.8797837670498875, 0.8714091607391934, -0.628958978867591, 0.49380147969660415, -0.6663578349373824, 0.14570184758600965, -0.4710388511314244, 0.7708214742640788, 0.06836525442683238, -1.2786368797129386, -0.5077556003990912, 0.45383439418987664, 1.1686877788409553, -0.3078567969393852, -2.2375730522738198, 1.0108200459611192, 0.21955367964983963, 1.2268011099696847, 0.48061693077695455, -0.5306373077054981, 1.5005367299570744, -2.1005486985463966}); + auto weights = 
NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + auto expected = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 21.748459867092496, 6.090581568657439, 7.51315897553838, 5.999534225166869, 22.58050883748054, 6.8600435676788605, 107.5976928688877, 191.56864939172544}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test4) { + auto labels = NDArrayFactory::create('c', {10,4}, {-1.9540657282602247, -0.37099621218123746, 0.24959541842365968, 0.4125896396216978, -0.8661959659606203, 0.3651479206362867, -1.7475031047706964, -1.0962133982440159, 0.8451229874730279, 0.6876932162478913, 1.2598782790596628, 0.9372328828104118, 1.383555504464105, -0.816048166961237, 0.009041816630426176, -0.004376554457540983, -0.2386352931506252, -0.6494407817111416, 1.7888273635934742, -1.2157303560822368, -0.2446697859467434, -0.3040881765177774, -0.25843499040765916, -0.16479617511053568, 1.8063435075905592, 0.36002291874022285, -0.43317974028771883, 1.070086390817373, -1.0788479808458253, -0.3364318348487324, -0.859106579072977, 0.43984270049845064, -0.23662331183489546, -1.263417124724063, -0.3123732566483939, -0.125249623799724, -1.951308433393268, -0.4925779190927575, -1.081735149025745, -1.9910331435034687}); + auto predictions = NDArrayFactory::create('c', {10,4}, {-1.7053977111021588, 1.7704125629388408, -0.0876171627499475, 0.9428762101237441, 0.9080108618240852, -0.478732892339118, -0.8189639230649537, 1.3359668242925342, -0.07499867017894829, 0.6169780756804321, -1.1891117691972148, -0.319354110980483, -1.4287263424900434, -0.3556443786879834, 0.6389682186473912, 
0.3161742985911756, 0.9047447733840537, -1.9974117226910393, 2.1067775658502326, 0.17035521714679938, -1.1393894489992826, 1.4570837278971687, 0.6312249731754015, -0.42793125692777634, -1.0685964336386844, -0.3590636581851568, -0.19147354841437528, -0.10128937266756889, -0.5714869078294972, 0.2682604831358205, 0.6608524575561853, 0.35658907103040305, -0.7053263272861181, -0.6318441042427088, 2.131292677079184, -0.3624048087249232, 1.6008209804575328, 0.1245980660014825, 1.0685424462364297, -0.5672594432046791}); + auto weights = NDArrayFactory::create('c', {1,1}, {1}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 60.74394998193965, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test5) { + auto labels = NDArrayFactory::create('c', {10,4}, {0.9165069946629816, 0.166426191704143, 0.13873357227527264, -0.5986162145785378, 0.4763504550662989, 1.2259816058633732, -0.4653205175596491, -1.7447031523970766, 1.349525448316014, 2.433089865629357, -2.54858150221601, -0.6060282162911894, 0.2625377104613349, -0.5007107584102752, 0.9576065700956302, -0.35787770401703584, -0.2608532564720665, 0.65688909921908, -0.1705876431948587, 1.2052884124800949, -0.976783296084278, 1.1163504624016534, -0.10545986164581109, -1.0632271027867568, 0.26460250034147065, -0.2299030354616135, -0.418989869909565, 0.7954060747536896, 0.37934127200736545, 0.8550487997440007, 0.2984909806904042, 0.1329065864221682, 1.478600294413247, 0.05421279873635542, -1.0552978360622536, -0.743808639782604, -1.3371851696151362, 2.7752972493355963, -1.6107187893743549, 1.5030902829432997}); + auto predictions = NDArrayFactory::create('c', {10,4}, {-3.398114657004427, 0.40587455906092945, 1.587706448479039, 
0.27394335709083156, 1.0463122023764637, -0.6552570653663903, -0.26929204111727345, -2.710461824817806, 0.9141296064806023, -0.7632270851454939, -0.4077235519855459, 0.5555107559107472, -0.6776140976423888, 1.2422270521180823, 0.2372445100636733, 0.08522757123963924, -2.708523129389936, 0.09738215252575103, -0.8797837670498875, 0.8714091607391934, -0.628958978867591, 0.49380147969660415, -0.6663578349373824, 0.14570184758600965, -0.4710388511314244, 0.7708214742640788, 0.06836525442683238, -1.2786368797129386, -0.5077556003990912, 0.45383439418987664, 1.1686877788409553, -0.3078567969393852, -2.2375730522738198, 1.0108200459611192, 0.21955367964983963, 1.2268011099696847, 0.48061693077695455, -0.5306373077054981, 1.5005367299570744, -2.1005486985463966}); + auto weights = NDArrayFactory::create('c', {1,1}, {1}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 15.189082270182983, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test6) { + auto labels = NDArrayFactory::create('c', {10,4}, {0.7712557146220891, 0.37344724586647443, -1.465944048516541, 0.3226845250222374, 0.3153238532645865, -0.6453963287132424, -1.7695663855309438, -0.31350813714835285, 0.6209850696184357, -1.0632582557661083, 0.8971205782356552, -0.7361143357044725, 0.4349813432397299, 1.1012674501462072, -1.846028584047857, -0.04711049067212126, 0.3511384383511822, -1.5908669452488973, 0.6271232025632083, -0.5370025878354387, 0.09775855957778733, 0.8465118033582384, -0.5118005514773271, -0.8215749768059044, -0.5154271246850248, -0.6614138367887438, -2.721743038982485, -0.20634785234624944, 1.074134378795222, -0.515671736473577, 0.33574452224656587, -0.4258992514621533, -1.6946210614398756, 
2.0853105493575246, -0.23223717047374226, -1.3145231337861756, -0.307739072607248, -0.13713627422120406, -0.05615471338688221, -0.7031780205843188}); + auto predictions = NDArrayFactory::create('c', {10,4}, {-0.8253096544930751, 0.81324545672996, 1.2530858908292535, 0.6881658781201572, 0.11626814971230247, 0.810096847233213, -0.41726775033902014, -0.07246036077805246, -0.3491325803119671, -0.7381717490678714, -1.258884944199858, 2.6195012275145992, 0.3241066697239042, -1.3306435333372646, -0.3413119919683999, 0.13167356361127197, -0.3992424507051653, 0.14454163796541403, -2.4931643208872316, 1.8740911656038526, -2.3404306490682956, -0.8036392545918644, -1.9726177395274997, -0.20128619801149433, -1.0680828820641624, -0.6228179015361869, 1.0785520122486962, -0.26148573195062036, -0.9154287856620913, 0.6612224269248097, -0.21735407368781667, 0.5584864652543093, 1.0208212201167435, -0.7560947201084579, -0.9092906572495081, 0.47525819203475833, 1.2215678456801444, -0.39319465979983964, 1.9435677135606038, 1.4540100039010526}); + auto weights = NDArrayFactory::create('c', {1,1}, {1}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 13.568564090650312, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test7) { + auto labels = NDArrayFactory::create('c', {10,4}, {-0.06125002348040258, 0.5143643450377119, 2.6790723358660036, -0.8032552006036418, -2.4374371040644163, -0.1562964773317163, -1.3957988654288038, 1.2791626503391635, -1.433421873294552, -1.1819478586737284, 0.05162930965054662, -0.538650473505593, -0.548171720093084, -0.3103900587344872, -2.3955103171953342, 0.7127238680062526, 0.7182079438418053, 1.1842662402382182, 0.09585189676958715, 0.9276146067349225, 
0.7856673461867428, 0.41368195133354113, -0.2939280190178078, -2.400566355562181, -1.1841519118039245, -1.066170501847581, -0.9274507409610022, 1.7671863041813334, -1.2849985781031494, -1.275990164491566, -0.8866824403466698, -0.6074077385015517, 0.7647344603897107, -1.048099070426831, 0.9433828938345293, -0.5591415819237762, 1.7962773615541947, -0.42365710367758247, -0.0385518907389571, -1.109959713481321}); + auto predictions = NDArrayFactory::create('c', {10,4}, {-0.7445687252538243, 0.2293875300325241, -1.0231630280206505, -0.18532545069458992, -0.07797403344353356, -0.9132035669873787, 0.9352296415512886, -1.7406458535354787, 0.8578334648119594, -0.6186274065269556, 0.4874824473654153, -0.9285817343788997, 0.1654680500853023, -0.6371334533926012, 1.3115245864160707, -2.072558735678832, 0.660795731844733, -0.34942292767044864, 0.05787182311194333, -0.12939210444705632, -0.6457028552461069, -0.6048992126598505, -0.17179604529778109, 1.292989642826032, -0.28867767615688045, 0.7635565516046265, -1.5464151753137487, -1.273368390129285, -1.074046012825826, -0.3534580692302915, 0.5757285568118223, 1.823271242883469, 0.31618576929075215, 0.5422847605415213, -0.7836698021860683, -0.6292022623165172, 2.1114596721927508, 0.4634986528550097, 0.08922001427846013, 1.5767749644913223}); + auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 198.318201904499, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test8) { + auto labels = NDArrayFactory::create('c', {10,4}, {1.2003157672694111, -1.0738078620687983, 1.4513396266923826, 0.5753935722952708, -0.5424028602429585, 
0.9816221437385002, -1.0566397385428794, 1.503481308203513, -0.6543147953583112, 1.7453669976827346, -0.1557689124924227, 0.3387794658137257, -1.2306868494328145, -0.3299042398395769, 0.026464968146954395, -1.5077479623528403, -0.27514168845621795, 0.18739335150879793, 1.7319910646645431, 1.5228099405663476, 0.8522684742808536, 0.2362049362675063, 0.2610756525241469, 0.457998065505686, -2.7342179885912623, -0.10968795695808314, 0.581598742956297, -1.9309885922934567, -1.5775788440607954, -0.04254899350225641, -0.3125858556254039, -1.1328154327730207, 0.00566243314780096, 0.8492052576274621, 0.05945202212214481, 1.4976918834497108, 0.8869512918387292, 0.4014181932175132, -0.015512552855187248, -1.3609667909108454}); + auto predictions = NDArrayFactory::create('c', {10,4}, {-1.1088399463364795, 0.09302972835006071, 0.033839927431215555, -0.39567507675572494, 0.8269497207597863, 1.111162272517752, 0.4930937252630912, -1.4561668998323452, 0.9417715392862969, -1.0553855492735509, 0.05848285303876081, 0.8852337518047972, -0.7472824481835305, 0.404906922583895, -0.2198309547562547, 1.9536515925189717, 0.8165036568007779, -0.19524282774410398, -0.09111693087754393, 1.1604245932512238, -0.6243762858131077, 1.4297003275591034, -0.17220079411538428, -2.3139504326793032, 0.3839796486999712, 2.0287791964679234, 0.1534441713632995, -0.6062103319229825, -0.4965880982906036, -0.373907747810053, -1.6566345746154432, 0.17534987728494222, -1.6713458890334796, 1.254628987947714, 1.914596591838086, -1.0816010467183583, 0.25033738231939673, -1.605752685708275, 1.1029112741353981, 0.3237822320282494}); + auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 10.709003499121707, 
1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_pairwssqerr_loss_test9) { + auto labels = NDArrayFactory::create('c', {10,4}, {0.054445708809271035, 2.107634671009908, -0.7906421810578572, -1.075840781788665, 0.11881403008710377, 0.8444812915085994, -0.305754504070933, 1.6429935026781464, 0.8155105031719394, 0.04900134907242568, 0.6847004530975871, 0.23315535615893132, 0.17011663306483038, -1.1865513655938285, 1.5931597087896407, -1.7937514075547496, -0.036695307704292295, -1.6416280650778925, 1.130578912176608, -1.1267224667674058, -0.8690453889645526, 0.6717944721406133, 0.0850200492927782, 1.1294419289013125, 0.2154793028698133, 0.4557382556428947, -0.7343674069166273, -0.20013117860162175, -0.6096905108192562, 0.42022878041905926, -0.7446306649741321, 0.01724811509597817, 1.843091605690758, 1.008879504632424, 1.198292190689489, -0.4474144618813475, 0.25202981742888664, 0.07036737843407408, 1.2400630276444486, -1.1072825235557615}); + auto predictions = NDArrayFactory::create('c', {10,4}, {-1.6788168943811437, 1.1823653279081687, -0.3580541857004183, -0.4449970504370699, -1.3031645333940127, 0.5755013195969282, -0.7997343141774744, -0.8806735270004084, 0.9705277499376251, -1.6360067944580943, 0.12579369136710156, 1.0525902242414313, -1.625751312422252, -0.03900152587147075, 0.4112500942756277, 0.6589999986358094, 0.6144107111689617, 2.8561269030217264, 1.5299963640392247, -0.314093051147705, 1.6523278218751989, -0.5504653447714114, 0.53395260877978, 0.409795577698306, 0.4466825218051794, 1.2382059301630401, 0.4834869732526594, -0.635409128905636, -1.9343816841697272, -0.4192523056060229, -1.0662979055059818, 0.4270901960618144, -0.7391311480757151, -0.8268168961897452, -1.0855715553457785, -9.410401291588706E-4, -0.7721838774717349, 0.4784019579457375, -0.6979798841469268, -0.319729737118584}); + auto weights = NDArrayFactory::create('c', {10,1}, {0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 
4.0}); + + sd::ops::mean_pairwssqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 17.686067864414472, 1e-5); + + +} +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test1) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test2) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1,4}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + 
ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test3) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1,1}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0.125, 0.5, 1.125, 2., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test4) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0., 0., 0., 0., 3.125, 4.5, 6.125, 8.,10.125,12.5,15.125,18.,21.125,24.5,28.125,32.,36.125,40.5,45.125,50.,55.125,60.5,66.125,72.}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test5) { + + auto labels = 
NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 612.5, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test6) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1,4}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 612.5, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test7) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 612.5, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test8) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto 
predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + weights.p(3, 0.); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 608.75, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test9) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 51.041668, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test10) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,3,1}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 51.041668, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test11) { + + auto labels 
= NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 51.041668, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test12) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,1}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 88.541664, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test13) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 25.520834, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, 
mean_sqerr_loss_test14) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1,4}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 25.520834, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test15) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 25.520834, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, mean_sqerr_loss_test16) { + + auto labels = NDArrayFactory::create('c', {2,3,4}); + auto predictions = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,1}); + + predictions.linspace(0.5, 0.5); + labels.linspace(1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + + sd::ops::mean_sqerr_loss op; + auto results = op.evaluate({&predictions, &weights, &labels}, {}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 44.270832, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test1) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0.37219834,0.29906943,0.27717763,0.45650762,0.23703849,0.51874399,0.20159303,0.58555031,0.17057693,0.65663081,0.14366767,0.73164123,0.12050423,0.81020868,0.10070664,0.89195037,0.08389302,0.97648883,1.01969337,0.06346401,0.05775976,1.15254164,0.04777273,1.2434181 }); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test2) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1,1}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0.37219834,0.29906943,0.27717763,0.45650762,0.23703849,0.51874399,0.20159303,0.58555031,0.17057693,0.65663081,0.14366767,0.73164123,0.12050423,0.81020868,0.10070664,0.89195037,0.08389302,0.97648883,1.01969337,0.06346401,0.05775976,1.15254164,0.04777273,1.2434181 }); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test3) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0.37219834,0.29906943,0.27717763,0.45650762,0.23703849,0.51874399,0.20159303,0.58555031,0.17057693,0.65663081,0.14366767,0.73164123,0.12050423,0.81020868,0.10070664,0.89195037,0.08389302,0.97648883,1.01969337,0.06346401,0.05775976,1.15254164,0.04777273,1.2434181 }); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test4) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {2,3,4}, {0.24719833, 0.54906946, 0.65217763,-0.04349237,0.86203849,-0.23125602, 1.07659304,-0.41444966,1.29557693,-0.59336919, 1.5186677 ,-0.76835877,1.74550426,-0.93979132, 1.9757067 ,-1.10804963,2.20889306,-1.27351117,-1.35530663, 2.56346393,2.68275976,-1.59745836, 2.92277265,-1.7565819 }); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + 
ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test5) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test6) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,1}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test7) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto 
*result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 11.2187976837, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test8) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 10.2187976837, 1e-5); + + +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test9) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,1}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 6.06840181351, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test10) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + 
sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.934899806976, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test11) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1,4}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.934899806976, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test12) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.851566493511, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test13) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = 
NDArrayFactory::create('c', {2,3,1}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 1.01140034199, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test14) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,4}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.467449903488, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test15) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,3,1}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.467449903488, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test16) { + + auto labels = 
NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.425783246756, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, sigm_cross_entropy_loss_test17) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3,1}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + weights.p(0, 0.); + weights.p(1, 0.); + weights.p(2, 0.); + + sd::ops::sigm_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 0.505700170994, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test1) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3}); + auto expected = NDArrayFactory::create('c', {2,3}, {1.39253557,1.44253552,1.44253552,1.44253552,1.39253557,1.44253552}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result 
= results.at(0); + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test2) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3}); + auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test3) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,1}); + auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test4) { + + auto labels = NDArrayFactory::create('c', 
{2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,3}); + auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test5) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + auto expected = NDArrayFactory::create('c', {2,3}, {-0.92835701,-1.12835705,-1.12835705,-1.12835705,-0.92835701,-1.12835705}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test6) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {1}); + + 
ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), 8.55521392822, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test7) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {2,3}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test8) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test9) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,3}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + 
sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -6.37014198303, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test10) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,3}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {2}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -2.12338066101, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test11) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto weights = NDArrayFactory::create('c', {1,3}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -1.06169033051, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test12) { + + auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,4}); + auto weights = 
NDArrayFactory::create('c', {2,1}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {3}, {}, {}, false); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(result->isScalar()); + ASSERT_NEAR(result->e(0), -2.18880319595, 1e-5); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test13) { + + auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,4}); + auto weights = NDArrayFactory::create('c', {2,1}); + auto expected = NDArrayFactory::create('c', {2,1}, {1.39253557,1.44253552}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {0.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, softmax_cross_entropy_loss_test14) { + + auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,4}); + auto weights = NDArrayFactory::create('c', {2,1}); + auto expected = NDArrayFactory::create('c', {2,1}, {-2.08880329, -2.28880334}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, 
softmax_cross_entropy_loss_test15) { + + auto labels = NDArrayFactory::create('c', {2,4},{0,1,1,0,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,4}); + auto weights = NDArrayFactory::create('c', {1,1}); + auto expected = NDArrayFactory::create('c', {2,1}, {-2.08880329, -2.28880334}); + + logits.linspace(0.1, 0.1); + weights.assign(0.5); + + sd::ops::softmax_cross_entropy_loss op; + auto results = op.evaluate({&logits, &weights, &labels}, {5.}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result = results.at(0); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test1) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.99926789,0.99926789,0.99926789,0.99926789,0.99926789,0.99926789,0.99926789,0.99926789}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108,3.99987108}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 1.}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + 
auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test2) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.95867589,0.95867589,0.95867589,0.95867589,0.95867589,0.95867589,0.95867589,0.95867589}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{1.93001527,1.93001527,1.93001527,1.93001527, 1.93001527,1.93001527,1.93001527,1.93001527}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., -10.5}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test3) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; 
+ + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0., 1.5}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test4) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', 
{4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568,0.37992568}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0.3, 1.5}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test5) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.3,0.3,0.3,0.3,0.3,0.3}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0.3, 1.5}, 
{0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test6) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.99832496,1.99832496,1.99832496,1.99832496,1.99832496,1.99832496}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188,3.99972188}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 1.5}, {0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test7) { + + const int batchSize = 2; + const int inSize = 10; + const int 
numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.75977136,0.75977136,0.75977136,0.75977136,0.75977136,0.75977136}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4, 0.4}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0.4, 0., 1.5}, {0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test8) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = 
NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.99930672,0.99930672,0.99930672,0.99930672, 0.99930672,0.99930672,0.99930672,0.99930672}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277,3.99996277}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 10.5}, {1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht,1e-4)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct,1e-4)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test9) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 4; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {0.99501777,0.99501777,0.99501777,0.99501777,0.99501777,0.99501777,0.99501777,0.99501777}); + auto expCt = NDArrayFactory::create('c', {batchSize, 
numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 0., 10.5}, {1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht,1e-4)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test10) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.99861344,1.99861344,1.99861344,1.99861344,1.99861344,1.99861344}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.99996277, 3.99996277, 3.99996277, 3.99996277,3.99996277, 3.99996277, 3.99996277, 3.99996277}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 10.5}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test11) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.99003554,1.99003554,1.99003554,1.99003554,1.99003554,1.99003554}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 0., 10.5}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests2, lstmCell_test12) { + + const int batchSize = 2; + const int inSize = 10; + const int numProj = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numProj}); + auto ct_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', 
{numProj, 4*numUnits}); + auto Wc = NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + ct_1.assign(3.); + Wx.assign(0.5); + Wh.assign(0.5); + Wc.assign(0.5); + Wp.assign(0.5); + b.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numProj}, {1.,1.,1.,1.,1.,1.}); + auto expCt = NDArrayFactory::create('c', {batchSize, numUnits},{3.,3.,3.,3.,3.,3.,3.,3.}); + + sd::ops::lstmCell op; + auto results = op.evaluate({&xt, &ht_1, &ct_1, &Wx, &Wh, &Wc, &Wp, &b}, {3., 1.,-5.}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + +#if !defined(__CUDABLAS__) || defined(HAVE_CUDNN) +TEST_F(DeclarableOpsTests2, ctc_loss_test1) { + constexpr int FRAME_LEN = 6 ; + constexpr int CLASS_LEN = 5 ; + constexpr int BATCH_LEN = 4 ; + constexpr int MIN_TARGET_LEN = 2; + constexpr int MAX_TARGET_LEN = 4; + +#if defined(HAVE_CUDNN) +//cudnn blankindex should be 0 + constexpr int BLANK_INDEX=0; +#else + constexpr int BLANK_INDEX=CLASS_LEN-1; +#endif + //logits were generated using numpy random and applying log softmax + //[ctc_loss.py](https://gist.github.com/quickwritereader/ca9858be201fd857348826a56e2bebc4) + auto logits = NDArrayFactory::create('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN }, + {-1.52900087f, -1.7423916f, -1.79369985f, -1.68980741f, -1.35771429f, + -2.08261997f, -1.65483307f, -1.31878488f, -1.38940393f, -1.78624192f, + -1.83125744f, -1.28989651f, -1.86882736f, -1.51760877f, -1.65575026f, + -1.59030191f, -2.09045484f, -2.01113821f, -1.31159853f, -1.3120046f, + -1.45263472f, -1.52268525f, -1.6567962f, -2.06986454f, -1.46546941f, + -1.25549694f, -1.86336982f, -1.64691575f, -1.69584239f, -1.69374889f, + 
-1.62384788f, -1.53256338f, -1.47943003f, -1.9953089f, -1.49995189f, + -1.58914748f, -2.14294273f, -1.89989005f, -1.26397295f, -1.40048678f, + -1.52242117f, -1.79940303f, -1.86987214f, -1.41871056f, -1.51299132f, + -1.41772259f, -1.27648263f, -1.87029582f, -1.71325761f, -1.93542947f, + -1.4372372f, -1.72814911f, -1.18767571f, -1.85569031f, -2.09127332f, + -1.99591619f, -1.17070749f, -1.91569048f, -1.66127429f, -1.52865783f, + -1.39319926f, -2.19674832f, -1.69619098f, -1.37916537f, -1.58285964f, + -1.85456282f, -1.91027747f, -1.35265643f, -1.76707679f, -1.32405154f, + -1.70063352f, -1.82894304f, -1.81275811f, -1.76677183f, -1.13084056f, + -2.01507311f, -1.50622804f, -1.55902412f, -1.4076143f, -1.66137954f, + -1.72469437f, -1.74285619f, -1.72109242f, -1.54947478f, -1.36444454f, + -1.78795939f, -1.62871901f, -1.43244094f, -1.83058005f, -1.43770547f, + -1.3577647f, -1.81454222f, -1.58227661f, -1.89836191f, -1.49373763f, + -1.52027507f, -1.41807732f, -1.54481537f, -1.86538837f, -1.76619851f, + -1.64547283f, -1.58328753f, -1.58442673f, -1.65941447f, -1.57762943f, + -1.54091641f, -1.76747862f, -1.56063854f, -1.76235545f, -1.45495771f, + -1.37294933f, -1.75871646f, -1.38392315f, -1.62238305f, -2.06866473f, + -1.98087487f, -1.49880371f, -2.14268396f, -1.22969736f, -1.47432277f + }); + + auto logits_length = NDArrayFactory::create('c', {BATCH_LEN}, {FRAME_LEN,FRAME_LEN,FRAME_LEN,FRAME_LEN}); + std::vector target ={2, 2, 2, 0, 1, 1, 0, 0, 1, 2, 2, 3, 0, 2, 1, 2}; +#if defined(HAVE_CUDNN) + //for cudnn blank index is -. 
therefore our targets cant be 0 + for(int i=0;i('c',{BATCH_LEN, MAX_TARGET_LEN}, target ); + + auto labels_len = NDArrayFactory::create('c', {BATCH_LEN}, {MIN_TARGET_LEN,MIN_TARGET_LEN +1, MAX_TARGET_LEN, MIN_TARGET_LEN +1}); + +#if defined(HAVE_CUDNN) + auto expected = NDArrayFactory::create('c', {BATCH_LEN}, {6.088762f, 5.9546056f, 7.5806675f, 5.5532417f}); +#else + auto expected = NDArrayFactory::create('c', {BATCH_LEN}, {6.0661564f, 6.4285727f, 7.7180986f, 4.936057f}); +#endif + sd::ops::ctc_loss op; + + //logits.printIndexedBuffer("logits"); + //labels.printIndexedBuffer("labels"); + + auto results = op.evaluate({&labels, &logits, &labels_len, &logits_length}, {}, {BLANK_INDEX}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *loss = results.at(0); + + //loss->printIndexedBuffer("loss"); + + ASSERT_TRUE(expected.isSameShape(loss)); + ASSERT_TRUE(expected.equalsTo(loss)); + +} + + +TEST_F(DeclarableOpsTests2, ctc_loss_grad_test1) { + constexpr int FRAME_LEN = 6 ; + constexpr int CLASS_LEN = 5 ; + constexpr int BATCH_LEN = 4 ; + constexpr int MAX_TARGET_LEN = 4; + constexpr int MIN_TARGET_LEN = 2; +#if defined(HAVE_CUDNN) +//cudnn blankindex should be 0 + constexpr int BLANK_INDEX=0; +#else + constexpr int BLANK_INDEX=CLASS_LEN-1; +#endif + //logits were generated using numpy random and applying log softmax + //[ctc_loss.py](https://gist.github.com/quickwritereader/ca9858be201fd857348826a56e2bebc4) + auto logits = NDArrayFactory::create('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN }, + {-1.52900087f, -1.7423916f, -1.79369985f, -1.68980741f, -1.35771429f, + -2.08261997f, -1.65483307f, -1.31878488f, -1.38940393f, -1.78624192f, + -1.83125744f, -1.28989651f, -1.86882736f, -1.51760877f, -1.65575026f, + -1.59030191f, -2.09045484f, -2.01113821f, -1.31159853f, -1.3120046f, + -1.45263472f, -1.52268525f, -1.6567962f, -2.06986454f, -1.46546941f, + -1.25549694f, -1.86336982f, -1.64691575f, -1.69584239f, -1.69374889f, + -1.62384788f, -1.53256338f, -1.47943003f, 
-1.9953089f, -1.49995189f, + -1.58914748f, -2.14294273f, -1.89989005f, -1.26397295f, -1.40048678f, + -1.52242117f, -1.79940303f, -1.86987214f, -1.41871056f, -1.51299132f, + -1.41772259f, -1.27648263f, -1.87029582f, -1.71325761f, -1.93542947f, + -1.4372372f, -1.72814911f, -1.18767571f, -1.85569031f, -2.09127332f, + -1.99591619f, -1.17070749f, -1.91569048f, -1.66127429f, -1.52865783f, + -1.39319926f, -2.19674832f, -1.69619098f, -1.37916537f, -1.58285964f, + -1.85456282f, -1.91027747f, -1.35265643f, -1.76707679f, -1.32405154f, + -1.70063352f, -1.82894304f, -1.81275811f, -1.76677183f, -1.13084056f, + -2.01507311f, -1.50622804f, -1.55902412f, -1.4076143f, -1.66137954f, + -1.72469437f, -1.74285619f, -1.72109242f, -1.54947478f, -1.36444454f, + -1.78795939f, -1.62871901f, -1.43244094f, -1.83058005f, -1.43770547f, + -1.3577647f, -1.81454222f, -1.58227661f, -1.89836191f, -1.49373763f, + -1.52027507f, -1.41807732f, -1.54481537f, -1.86538837f, -1.76619851f, + -1.64547283f, -1.58328753f, -1.58442673f, -1.65941447f, -1.57762943f, + -1.54091641f, -1.76747862f, -1.56063854f, -1.76235545f, -1.45495771f, + -1.37294933f, -1.75871646f, -1.38392315f, -1.62238305f, -2.06866473f, + -1.98087487f, -1.49880371f, -2.14268396f, -1.22969736f, -1.47432277f + }); + + auto logits_length = NDArrayFactory::create('c', {BATCH_LEN}, {FRAME_LEN,FRAME_LEN,FRAME_LEN,FRAME_LEN}); + std::vector target ={2, 2, 2, 0, 1, 1, 0, 0, 1, 2, 2, 3, 0, 2, 1, 2}; +#if defined(HAVE_CUDNN) + //for cudnn blank index is 0. 
therefore our targets cant be 0 + for(int i=0;i('c',{BATCH_LEN, MAX_TARGET_LEN}, target ); + auto labels_len = NDArrayFactory::create('c', {BATCH_LEN}, {MIN_TARGET_LEN, MIN_TARGET_LEN +1, MAX_TARGET_LEN, MIN_TARGET_LEN +1}); +#if defined(HAVE_CUDNN) +//results for blank Index=0 + auto expected = NDArrayFactory::create('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN}, + { + -0.2673936f, 0.17510113f, 0.16634358f, -0.33129925f, 0.2572481f, + -0.17626494f, 0.19112396f, 0.2674601f, -0.44990796f, 0.1675888f, + -0.33695614f, 0.27529928f, 0.1543045f, -0.28359637f, 0.19094874f, + -0.26243734f, 0.1236309f, 0.13383625f, -0.26430953f, 0.26927972f, + -0.33964074f, 0.21812534f, 0.1907491f, -0.3002034f, 0.23096953f, + -0.200618f, 0.15514892f, 0.19264314f, -0.3310032f, 0.18382908f, + -0.04921098f, 0.21598133f, -0.52588296f, 0.13597165f, 0.22314091f, + -0.38300496f, 0.11730913f, -0.2633105f, 0.2825293f, 0.24647695f, + -0.34686768f, 0.16539758f, -0.280806f, 0.24202588f, 0.22025016f, + -0.21347934f, 0.19306758f, -0.304228f, 0.18027757f, 0.14436226f, + 0.02692442f, -0.08318196f, -0.2236172f, 0.15634498f, 0.12352975f, + 0.03155032f, -0.5855137f, 0.14724013f, 0.18989684f, 0.2168265f, + 0.10374172f, 0.11116405f, -0.67208123f, 0.25178862f, 0.20538692f, + 0.09189357f, 0.14803931f, 0.00725803f, -0.5132462f, 0.2660552f, + -0.4309733f, 0.16058321f, 0.16320339f, -0.21557501f, 0.32276183f, + -0.32850766f, 0.2217448f, 0.21034124f, -0.2934553f, 0.18987685f, + 0.06212101f, 0.1750198f, 0.17887063f, -0.38780046f, -0.02821094f, + 0.05002825f, 0.19618073f, 0.23872548f, 0.16032055f, -0.64525515f, + -0.19972575f, -0.38012666f, 0.20550671f, 0.14981383f, 0.22453187f, + -0.02966774f, -0.34505254f, 0.21335125f, -0.00961271f, 0.17098173f, + -0.04058227f, -0.03726651f, 0.16733989f, -0.295955f, 0.20646395f, + -0.05670565f, 0.12657055f, -0.00966609f, -0.2936089f, 0.23341022f, + -0.01142454f, 0.17226583f, -0.2727364f, -0.01445916f, 0.12635438f, + -0.23244353f, 0.22339724f, -0.5122685f, 0.29238105f, 0.2289337f + }); +#else 
+ auto expected = NDArrayFactory::create('c', {BATCH_LEN, FRAME_LEN, CLASS_LEN}, + { + 0.21675213f, 0.17510113f, -0.27113008f, 0.18455505f, -0.30527824f, + 0.12460334f, 0.19112396f, -0.44803357f, 0.24922381f, -0.11691755f, + 0.16021198f, 0.27529928f, -0.28298444f, 0.21923551f, -0.37176234f, + 0.20386407f, 0.1236309f, -0.15528734f, 0.2693891f, -0.44159663f, + 0.23395306f, 0.21812534f, -0.36457074f, 0.12620285f, -0.21371071f, + 0.28493422f, 0.15514892f, -0.4384392f, 0.18344463f, -0.18508859f, + 0.19713868f, -0.61835873f, 0.22776747f, 0.13597165f, 0.05748086f, + 0.20409954f, -0.17006806f, 0.14958507f, 0.2825293f, -0.46614605f, + 0.218183f, -0.28762838f, 0.15414338f, 0.24202588f, -0.32672384f, + 0.09618269f, -0.40792802f, 0.15407808f, 0.18027757f, -0.02261038f, + -0.40063405f, -0.04311697f, 0.3049292f, 0.15634498f, -0.01752307f, + -0.43639395f, 0.31014743f, 0.14724013f, 0.18989684f, -0.21089047f, + 0.24827974f, -0.8280775f, 0.1833807f, 0.25178862f, 0.1446285f, + 0.15652135f, 0.05439584f, -0.5887033f, 0.17083165f, 0.20695446f, + 0.1825678f, 0.1605832f, -0.04697506f, 0.17088373f, -0.4670597f, + 0.13331066f, 0.2217448f, -0.46589473f, 0.24472642f, -0.13388708f, + 0.17822751f, 0.1750198f, -0.27072078f, -0.15830047f, 0.07577389f, + 0.16730122f, 0.19618073f, 0.23872548f, -0.618405f, 0.01619747f, + -0.41614607f, 0.16291247f, 0.20550671f, 0.14981383f, -0.10208681f, + -0.32300252f, 0.2421792f, -0.01448151f, 0.15483606f, -0.05953133f, + -0.03524604f, 0.1660878f, -0.24423766f, 0.19025035f, -0.07685445f, + 0.1546654f, 0.00699046f, -0.26606354f, 0.17164008f, -0.06723261f, + 0.2533586f, -0.31069174f, -0.07983261f, 0.19742766f, -0.06026195f, + 0.1379485f, -0.47723943f, 0.11733948f, 0.29238105f, -0.07042958 + }); +#endif + sd::ops::ctc_loss_grad op; + + auto results = op.evaluate({&labels, &logits, &labels_len, &logits_length}, {}, {BLANK_INDEX}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *gradient = results.at(0); + + //gradient->printIndexedBuffer("gradient"); + + 
ASSERT_TRUE(expected.isSameShape(gradient)); + ASSERT_TRUE(expected.equalsTo(gradient, 1.e-06)); + +} + +#endif + +TEST_F(DeclarableOpsTests2, ctc_beam_test1) { + constexpr int CLASS_LEN = 5 ; + constexpr int BATCH_LEN = 1 ; + constexpr int MAX_FRAME_LEN = 3; + constexpr int NBEST_LEN = 2; + constexpr int BEAM_WIDTH = 3; + constexpr int BLANK_INDEX=CLASS_LEN-1; + auto logits = NDArrayFactory::create('c', { BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }, + { + -2.578319f, -1.091237f, -1.519336f, -2.115322f, -1.390921f, + -1.901657f, -2.46196f, -1.718925f, -0.837558f, -1.874794f, + -1.761921f, -1.125581f, -2.378538f, -1.907196f, -1.336974f + }); + auto logits_length = NDArrayFactory::create('c', { BATCH_LEN }, { 3 }); + + auto output_sequence = NDArrayFactory::create('c', { BATCH_LEN, NBEST_LEN, MAX_FRAME_LEN}); + auto output_seq_prob = NDArrayFactory::create('c', { BATCH_LEN, NBEST_LEN}); + auto output_seq_length = NDArrayFactory::create('c', { BATCH_LEN, NBEST_LEN}); + + auto expected_seq = NDArrayFactory::create('c', {BATCH_LEN, NBEST_LEN, MAX_FRAME_LEN}, + {1, 3, 0, + 1, 3, 1}); + + auto expected_length = NDArrayFactory::create('c', {BATCH_LEN, NBEST_LEN }, {2, 3}); + + auto expected_probs = NDArrayFactory::create('c', {BATCH_LEN, NBEST_LEN }, {-2.817627f, -3.054376f}); + + sd::ops::ctc_beam op; + + auto result = op.execute({ &logits, &logits_length}, {&output_sequence, &output_seq_prob, &output_seq_length }, {BLANK_INDEX, BEAM_WIDTH, NBEST_LEN}); + + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(expected_seq.equalsTo(output_sequence)); + ASSERT_TRUE(expected_probs.equalsTo(output_seq_prob)); + ASSERT_TRUE(expected_length.equalsTo(output_seq_length)); + +} + +TEST_F(DeclarableOpsTests2, ctc_beam_test2) { + constexpr int CLASS_LEN = 5 ; + constexpr int BATCH_LEN = 4 ; + constexpr int MIN_FRAME_LEN = 4; + constexpr int MAX_FRAME_LEN = 6; + constexpr int NBEST_LEN = 1; + constexpr int BEAM_WIDTH = 3; + constexpr int BLANK_INDEX=CLASS_LEN-1; + + auto logits = 
NDArrayFactory::create('c', {BATCH_LEN, MAX_FRAME_LEN, CLASS_LEN }, + {-1.52900087f, -1.7423916f, -1.79369985f, -1.68980741f, -1.35771429f, + -2.08261997f, -1.65483307f, -1.31878488f, -1.38940393f, -1.78624192f, + -1.83125744f, -1.28989651f, -1.86882736f, -1.51760877f, -1.65575026f, + -1.59030191f, -2.09045484f, -2.01113821f, -1.31159853f, -1.3120046f, + -1.45263472f, -1.52268525f, -1.6567962f, -2.06986454f, -1.46546941f, + -1.25549694f, -1.86336982f, -1.64691575f, -1.69584239f, -1.69374889f, + -1.62384788f, -1.53256338f, -1.47943003f, -1.9953089f, -1.49995189f, + -1.58914748f, -2.14294273f, -1.89989005f, -1.26397295f, -1.40048678f, + -1.52242117f, -1.79940303f, -1.86987214f, -1.41871056f, -1.51299132f, + -1.41772259f, -1.27648263f, -1.87029582f, -1.71325761f, -1.93542947f, + -1.4372372f, -1.72814911f, -1.18767571f, -1.85569031f, -2.09127332f, + -1.99591619f, -1.17070749f, -1.91569048f, -1.66127429f, -1.52865783f, + -1.39319926f, -2.19674832f, -1.69619098f, -1.37916537f, -1.58285964f, + -1.85456282f, -1.91027747f, -1.35265643f, -1.76707679f, -1.32405154f, + -1.70063352f, -1.82894304f, -1.81275811f, -1.76677183f, -1.13084056f, + -2.01507311f, -1.50622804f, -1.55902412f, -1.4076143f, -1.66137954f, + -1.72469437f, -1.74285619f, -1.72109242f, -1.54947478f, -1.36444454f, + -1.78795939f, -1.62871901f, -1.43244094f, -1.83058005f, -1.43770547f, + -1.3577647f, -1.81454222f, -1.58227661f, -1.89836191f, -1.49373763f, + -1.52027507f, -1.41807732f, -1.54481537f, -1.86538837f, -1.76619851f, + -1.64547283f, -1.58328753f, -1.58442673f, -1.65941447f, -1.57762943f, + -1.54091641f, -1.76747862f, -1.56063854f, -1.76235545f, -1.45495771f, + -1.37294933f, -1.75871646f, -1.38392315f, -1.62238305f, -2.06866473f, + -1.98087487f, -1.49880371f, -2.14268396f, -1.22969736f, -1.47432277f + }); + + auto logits_length = NDArrayFactory::create('c', {BATCH_LEN}, {MAX_FRAME_LEN, MAX_FRAME_LEN, MAX_FRAME_LEN, MAX_FRAME_LEN}); + + auto expected_seq = NDArrayFactory::create('c', {BATCH_LEN, NBEST_LEN, 
MAX_FRAME_LEN}, + {3, 1, 3, 0, 0, 0, + 2, 3, 1, 0, 0, 0, + 3, 2, 3, 0, 0, 0, + 0, 1, 0, 0, 0, 0}); + + auto expected_length = NDArrayFactory::create('c', {BATCH_LEN, NBEST_LEN }, {3, 3, 3, 3}); + + auto expected_probs = NDArrayFactory::create('c', {BATCH_LEN, NBEST_LEN }, {-5.497302f, -5.469760f, -5.338807f,-5.520249f}); + + sd::ops::ctc_beam op; + + auto results = op.evaluate({ &logits, &logits_length}, {}, {BATCH_LEN, BEAM_WIDTH, NBEST_LEN}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *result_sequence = results.at(0); + auto *result_probs = results.at(1); + auto *result_sequence_length = results.at(2); + + ASSERT_TRUE(expected_seq.equalsTo(result_sequence)); + ASSERT_TRUE(expected_probs.equalsTo(result_probs)); + ASSERT_TRUE(expected_length.equalsTo(result_sequence_length)); + +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests3.cpp new file mode 100644 index 000000000..c4c90a586 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -0,0 +1,2764 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include + + +using namespace sd; +using namespace sd::graph; + +class DeclarableOpsTests3 : public testing::Test { +public: + + DeclarableOpsTests3() { +// + } +}; + + +TEST_F(DeclarableOpsTests3, Test_Tile_1) { + auto x= NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto rep_vector= NDArrayFactory::create('c', {1, 2}, {2, 2}); + std::vector reps({2, 2}); + + auto exp = x.tile(reps); + + sd::ops::tile op; + auto result = op.evaluate({&x, &rep_vector}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests3, Test_Tile_2) { + auto x= NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + std::vector reps({2, 2}); + + auto exp = x.tile(reps); + + sd::ops::tile op; + auto result = op.evaluate({&x}, {}, {2, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests3, Test_Permute_1) { + auto x= NDArrayFactory::create('c', {2, 3, 4}); + auto permute= NDArrayFactory::create('c', {1, 3}, {0, 2, 1}); + auto exp= NDArrayFactory::create('c', {2, 4, 3}); + + sd::ops::permute op; + auto result = op.evaluate({&x, &permute}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); +} + +TEST_F(DeclarableOpsTests3, Test_Permute_2) { + auto x= NDArrayFactory::create('c', {2, 3, 4}); + auto exp= NDArrayFactory::create('c', {4, 3, 2}); + + sd::ops::permute op; + auto result = op.evaluate({&x}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + +} + + +TEST_F(DeclarableOpsTests3, 
Test_Unique_1) { + auto x= NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); + auto expV= NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto expI= NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); +// auto expI= NDArrayFactory::create('c', {3}, {0, 1, 4}); + + sd::ops::unique op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto i = result.at(1); + // v->printIndexedBuffer("Values"); + // i->printIndexedBuffer("Indices"); + // i->printShapeInfo("Indices shape"); + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + + +} + +TEST_F(DeclarableOpsTests3, Test_Unique_2) { + auto x= NDArrayFactory::create('c', {1, 5}, {1.f, 2.f, 1.f, 2.f, 3.f}); + auto expV= NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto expI= NDArrayFactory::create('c', {5}, {0, 1, 0, 1, 2}); + auto expC= NDArrayFactory::create('c', {3}, {2, 2, 1}); + + sd::ops::unique_with_counts op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(3, result.size()); + + auto v = result.at(0); + auto i = result.at(1); + auto c = result.at(2); + + // v->printShapeInfo(); + // v->printIndexedBuffer("Values"); + // i->printShapeInfo(); + // i->printIndexedBuffer("Indices"); + // c->printShapeInfo(); + // c->printIndexedBuffer("Counts"); + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); +} + +TEST_F(DeclarableOpsTests3, Test_Rint_1) { + auto x= NDArrayFactory::create('c', {1, 7}, {-1.7f, -1.5f, -0.2f, 0.2f, 1.5f, 1.7f, 2.0f}); + auto exp= NDArrayFactory::create('c', {1, 7}, {-2.f, -2.f, -0.f, 0.f, 2.f, 2.f, 2.f}); + + sd::ops::rint op; + auto result = op.evaluate({&x}, {}, {}); + 
ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests3, Test_Norm_1) { + auto x = NDArrayFactory::create('c', {100, 100}); + x.linspace(1); + + std::vector empty; + std::vector dims({1}); + sd::ops::norm op; + + auto result0 = op.evaluate({&x}, {0.}, {}); + + auto z0 = result0.at(0); + auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty); + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp0.equalsTo(z0)); + + auto result1 = op.evaluate({&x}, {1.}, {1}); + ASSERT_EQ(result1.status(), ND4J_STATUS_OK); + auto z1 = result1.at(0); + // z1->printIndexedBuffer("Z1"); + auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims); + // exp1.printIndexedBuffer("EXP1"); + // z1->printShapeInfo("Z1 shape"); + // exp1.printShapeInfo("EXP1 shape"); + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); + + auto result4 = op.evaluate({&x}, {4.}, {1}); + + auto z4 = result4.at(0); + auto exp4= x.reduceAlongDimension(reduce::NormMax, dims); + ASSERT_TRUE(exp4.isSameShape(z4)); + ASSERT_TRUE(exp4.equalsTo(z4)); +} + + +TEST_F(DeclarableOpsTests3, Test_Norm_2) { + auto x = NDArrayFactory::create('c', {100, 100}); + x.linspace(1); + auto axis= NDArrayFactory::create('c', {1, 1}, {1}); + + std::vector empty; + std::vector dims({1}); + sd::ops::norm op; + + auto result0 = op.evaluate({&x}, {0}, {}); + + auto z0 = result0.at(0); + auto exp0 = x.reduceAlongDimension(reduce::NormFrobenius, empty); + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp0.equalsTo(z0)); + + + + auto result1 = op.evaluate({&x, &axis}, {1}, {}); + + auto z1 = result1.at(0); + auto exp1 = x.reduceAlongDimension(reduce::Norm2, dims); + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); + + auto result4 = op.evaluate({&x, &axis}, {4}, {}); + + auto z4 = result4.at(0); + auto exp4= x.reduceAlongDimension(reduce::NormMax, dims); + ASSERT_TRUE(exp4.isSameShape(z4)); + 
ASSERT_TRUE(exp4.equalsTo(z4)); + +} + +TEST_F(DeclarableOpsTests3, Test_ListDiff_1) { + auto x= NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto y= NDArrayFactory::create('c', {3}, {1.f, 3.f, 5.f}); + + auto exp0= NDArrayFactory::create('c', {3}, {2.f, 4.f, 6.f}); + auto exp1= NDArrayFactory::create('c', {3}, {1, 3, 5}); + + sd::ops::listdiff op; + auto result = op.evaluate({&x, &y}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z0 = result.at(0); + auto z1 = result.at(1); + + z0->getDataBuffer()->syncToSpecial(true); // force sync + z1->getDataBuffer()->syncToSpecial(true); // force sync + + ASSERT_TRUE(exp0.isSameShape(z0)); + ASSERT_TRUE(exp0.equalsTo(z0)); + + ASSERT_TRUE(exp1.isSameShape(z1)); + ASSERT_TRUE(exp1.equalsTo(z1)); + +} + +TEST_F(DeclarableOpsTests3, Test_Range_1) { + auto start = NDArrayFactory::create(0.3f); + auto stop = NDArrayFactory::create(-5.f); + auto step = NDArrayFactory::create(-0.33f); + auto exp= NDArrayFactory::create('c', {17}, { 0.3f, -0.03f, -0.36f, -0.69f, -1.02f, -1.35f, -1.68f, -2.01f, -2.34f, -2.67f, -3.f, -3.33f, -3.66f, -3.99f, -4.32f, -4.65f, -4.98f}); + + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(DeclarableOpsTests3, Test_Range_2) { + auto start= NDArrayFactory::create('c', {1, 1}, {2.f}); + auto stop= NDArrayFactory::create('c', {1, 1}, {0.f}); + auto step= NDArrayFactory::create('c', {1, 1}, {-1.f}); + auto exp= NDArrayFactory::create('c', {2}, {2.f, 1.f}); + + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(DeclarableOpsTests3, Test_Range_3) { + auto start= NDArrayFactory::create('c', {1, 1}, {0.f}); + auto stop= 
NDArrayFactory::create('c', {1, 1}, {2.f}); + auto step= NDArrayFactory::create('c', {1, 1}, {1.f}); + auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); + + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(DeclarableOpsTests3, Test_Range_10) { + auto start= NDArrayFactory::create('c', {1, 1}, {0.f}); + auto stop= NDArrayFactory::create('c', {1, 1}, {2.f}); + auto step= NDArrayFactory::create('c', {1, 1}, {1.f}); + auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); + + sd::ops::range op; + auto result = op.evaluate({&start, &stop, &step}, {sd::DataType::DOUBLE}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(DeclarableOpsTests3, Test_Range_4) { + auto exp= NDArrayFactory::create('c', {13}, {-10.f, -8.334f, -6.668f, -5.002f, -3.336f, -1.67f, -0.004f, 1.662f, 3.328f, 4.994f, 6.66f, 8.326f, 9.992f}); + + sd::ops::range op; + auto result = op.evaluate({}, {-10., 10., 1.666}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(DeclarableOpsTests3, Test_Range_5) { + auto exp= NDArrayFactory::create('c', {2}, {2.f, 1.f}); + + sd::ops::range op; + auto result = op.evaluate({}, {2, 0, -1}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(DeclarableOpsTests3, Test_Range_6) { + auto exp= NDArrayFactory::create('c', {2}, {0.f, 1.f}); + + sd::ops::range op; + auto result = op.evaluate({}, {0, 2, 1}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} 
+ +TEST_F(DeclarableOpsTests3, Test_Range_7) { + auto exp= NDArrayFactory::create('c', {10}, {10.f, 8.334f, 6.668f, 5.002f, 3.336f, 1.67f, 0.004f, -1.662f, -3.328f, -4.994f}); + + sd::ops::range op; + auto result = op.evaluate({}, {10,-5,-1.666}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + + +TEST_F(DeclarableOpsTests3, Test_Range_8) { + auto exp= NDArrayFactory::create('c', {2}, {2, 1}); + + sd::ops::range op; + auto result = op.evaluate({}, {}, {2, 0, -1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(DeclarableOpsTests3, Test_Range_9) { + auto exp= NDArrayFactory::create('c', {2}, {0, 1}); + + sd::ops::range op; + auto result = op.evaluate({}, {}, {0, 2, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) { + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + auto exp = MmulHelper::mmul(&x, &y); + + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 3, 3, 3, 3, 3, 3, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(3, result.size()); + + for (int e = 0; e < 3; e++) { + auto z = result.at(e); + +// exp->printIndexedBuffer("e"); +// z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + + delete exp; + +} + +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) { + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 
3}, {0, 0, 0}); + auto x= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + auto exp = MmulHelper::mmul(&x, &y); + + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 3, 3, 3, 3, 3, 3, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(3, result.size()); + + for (int e = 0; e < 3; e++) { + auto z = result.at(e); + + //exp->printIndexedBuffer("e"); + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + + delete exp; + +} + +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) { + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y= NDArrayFactory::create('f', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + + auto exp = MmulHelper::mmul(&x, &y); + + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 111, 3, 3, 3, 3, 3, 3, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(3, result.size()); + + for (int e = 0; e < 3; e++) { + auto z = result.at(e); + +// exp->printIndexedBuffer("e"); +// z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + + delete exp; + +} + +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) { + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto y= NDArrayFactory::create('f', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + auto exp = MmulHelper::mmul(&x, &y); + + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 5, 4, 3, 5, 3, 5, 3}); + ASSERT_EQ(ND4J_STATUS_OK, 
result.status()); + + ASSERT_EQ(3, result.size()); + + for (int e = 0; e < 3; e++) { + auto z = result.at(e); + + //exp->printIndexedBuffer("e"); + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + + delete exp; + +} + +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) { + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + auto y= NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + auto exp = MmulHelper::mmul(&x, &y); + + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 5, 4, 3, 3, 4, 5, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(3, result.size()); + + for (int e = 0; e < 3; e++) { + auto z = result.at(e); + + //exp->printIndexedBuffer("e"); + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + + delete exp; + +} + + +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) { + auto a= NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('f', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto y= NDArrayFactory::create('f', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + + auto exp = MmulHelper::mmul(&x, &y); + + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {111, 111, 2, 3, 5, 2, 5, 2, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(3, result.size()); + + for (int e = 0; e < 3; e++) { + auto z = result.at(e); + + //exp->printIndexedBuffer("e"); + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + + delete exp; + +} + +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) { + auto a= 
NDArrayFactory::create('c', {1, 3}, {1, 1, 1}); + auto b= NDArrayFactory::create('c', {1, 3}, {0, 0, 0}); + auto x= NDArrayFactory::create('c', {2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + auto y= NDArrayFactory::create('c', {5, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); + + auto exp = MmulHelper::mmul(&x, &y); + + // exp->printShapeInfo("exp shape"); + + sd::ops::batched_gemm op; + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(3, result.size()); + + for (int e = 0; e < 3; e++) { + auto z = result.at(e); + + //exp->printIndexedBuffer("e"); + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp->isSameShape(z)); + ASSERT_TRUE(exp->equalsTo(z)); + } + + delete exp; + +} + +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) { + auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); + auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create('c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto y = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + + sd::ops::batched_gemm op; + try { + auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}); + + ASSERT_TRUE(false); + } catch (std::invalid_argument &e) { + // + } +} + +TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_2) { + auto a = NDArrayFactory::create('c', {1, 3}, {1.f, 1.f, 1.f}); + auto b = NDArrayFactory::create('c', {1, 3}, {0.f, 0.f, 0.f}); + auto x = NDArrayFactory::create('c', {2, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto y = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + + auto z = NDArrayFactory::create('c', {2, 3}); + + sd::ops::batched_gemm op; + try { + auto result = op.execute({&a, &b, &x, &x, 
&x, &y, &y, &y}, {&z}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3}, {}); + ASSERT_TRUE(false); + } catch (std::invalid_argument &e) { + // + } +} + +TEST_F(DeclarableOpsTests3, Test_ReverseDivide_1) { + auto x= NDArrayFactory::create('c', {1, 3}, {2, 2, 2}); + auto y= NDArrayFactory::create('c', {1, 3}, {4, 6, 8}); + auto exp= NDArrayFactory::create('c', {1, 3}, {2, 3, 4}); + + sd::ops::reversedivide op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, sruCell_test1) { + + const int batchSize = 2; + const int inSize = 5; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ct_1= NDArrayFactory::create('c', {batchSize, inSize}); + auto w = NDArrayFactory::create('c', {inSize, 3*inSize}); + auto b = NDArrayFactory::create('c', {2*inSize}); + + xt.assign(1.); + ct_1.assign(2.); + w.assign(0.5); + b.assign(0.7); + + auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f, 0.96674103f}); + auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f, 2.01958286f}); + + sd::ops::sruCell op; + auto results = op.evaluate({&xt, &ct_1, &w, &b}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, sruCell_test2) { + + const int batchSize = 2; + const int inSize = 5; + + auto xt = 
NDArrayFactory::create('c', {batchSize, inSize}); + auto ct_1= NDArrayFactory::create('c', {batchSize, inSize}); + auto w = NDArrayFactory::create('c', {inSize, 3*inSize}); + auto b = NDArrayFactory::create('c', {2*inSize}); + + xt.assign(1.); + ct_1.assign(2.); + w.assign(0.5); + b.assign(-1.); + + auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f, 0.97542038f}); + auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f, 2.09121276f}); + + sd::ops::sruCell op; + auto results = op.evaluate({&xt, &ct_1, &w, &b}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, sruCell_test3) { + + const int batchSize = 2; + const int inSize = 5; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ct_1= NDArrayFactory::create('c', {batchSize, inSize}); + auto w = NDArrayFactory::create('c', {inSize, 3*inSize}); + auto b = NDArrayFactory::create('c', {2*inSize}); + + xt.assign(10.); + ct_1.assign(1.); + w.assign(0.5); + b.assign(-1.); + + auto expHt= NDArrayFactory::create('c', {batchSize, inSize}, {0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f, 0.76159416f}); + auto expCt= NDArrayFactory::create('c', {batchSize, inSize}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + + sd::ops::sruCell op; + auto results = op.evaluate({&xt, &ct_1, &w, &b}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(0); + auto *ct = 
results.at(1); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + ASSERT_TRUE(expCt.isSameShape(ct)); + ASSERT_TRUE(expCt.equalsTo(ct)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, gruCell_test1) { + + const int batchSize = 2; + const int inSize = 10; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wru = NDArrayFactory::create('c', {(inSize+numUnits), 2*numUnits}); + auto Wc = NDArrayFactory::create('c', {(inSize+numUnits), numUnits}); + auto bru = NDArrayFactory::create('c', {2*numUnits}); + auto bc = NDArrayFactory::create('c', {numUnits}); + + xt.assign(1.); + ht_1.assign(2.); + Wru.assign(0.5); + Wc.assign(0.5); + bru.assign(0.7); + bc.assign(0.7); + + auto expHt = NDArrayFactory::create('c', {batchSize, numUnits}, {1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f, 1.99993872f}); + + sd::ops::gruCell op; + auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(3); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, gruCell_test2) { + + const int batchSize = 2; + const int inSize = 10; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wru = NDArrayFactory::create('c', {(inSize+numUnits), 2*numUnits}); + auto Wc = NDArrayFactory::create('c', {(inSize+numUnits), numUnits}); + auto bru = NDArrayFactory::create('c', {2*numUnits}); + auto bc = NDArrayFactory::create('c', {numUnits}); + + xt.assign(1.); + ht_1.assign(0.); + Wru.assign(1.5); + Wc.assign(1.5); + bru.assign(-10); + bc.assign(-10); + + auto 
expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f, 0.00669224f}); + + sd::ops::gruCell op; + auto results = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *ht = results.at(3); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, gruCell_test3) { + + const int batchSize = 2; + const int inSize = 10; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {batchSize, inSize}); + auto ht_1= NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wru = NDArrayFactory::create('c', {(inSize+numUnits), 2*numUnits}); + auto Wc = NDArrayFactory::create('c', {(inSize+numUnits), numUnits}); + auto bru = NDArrayFactory::create('c', {2*numUnits}); + auto bc = NDArrayFactory::create('c', {numUnits}); + + xt.assign(1.); + ht_1.assign(0.); + Wru.assign(0.1); + Wc.assign(0.1); + bru.assign(1); + bc.assign(1); + + auto expHt= NDArrayFactory::create('c', {batchSize, numUnits}, {0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f, 0.1149149f}); + + + sd::ops::gruCell op; + auto result = op.evaluate({&xt, &ht_1, &Wru, &Wc, &bru, &bc}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *ht = result.at(3); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, invertPermutation_test1) { + + auto input= NDArrayFactory::create('c', {1, 8}, {5,2,7,4,6,3,1,0}); + auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); + + sd::ops::invert_permutation op; + auto result = op.evaluate({&input}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + 
ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, invertPermutation_test2) { + + auto input= NDArrayFactory::create('c', {1, 8}, {5,2,7,4,6,3,1,0}); + auto expected= NDArrayFactory::create('c', {1, 8}, {7, 6, 1, 5, 3, 0, 4, 2}); + + + sd::ops::invert_permutation op; + auto result = op.evaluate({&input}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, invertPermutation_test3) { + + auto input= NDArrayFactory::create('c', {1, 8}, {1,2,0,4,6,3,5,7}); + auto expected= NDArrayFactory::create('c', {1, 8}, {2, 0, 1, 5, 3, 6, 4, 7}); + + sd::ops::invert_permutation op; + auto result = op.evaluate({&input}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, diag_test1) { + + auto input= NDArrayFactory::create('c', {3, 2}); + input.linspace(1); + + auto expected= NDArrayFactory::create('c', {3,2,3,2}, {1,0,0,0,0,0, 0,2,0,0,0,0, 0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); + + sd::ops::diag op; + auto result = op.evaluate({&input}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, diag_test2) { + + auto input= NDArrayFactory::create('c', {2, 3}); + input.linspace(1); + + auto expected= NDArrayFactory::create('c', {2,3,2,3}, {1,0,0,0,0,0, 0,2,0,0,0,0, 
0,0,3,0,0,0, 0,0,0,4,0,0, 0,0,0,0,5,0, 0,0,0,0,0,6}); + + sd::ops::diag op; + auto result = op.evaluate({&input}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, diag_test_vector) { + + + auto input = NDArrayFactory::linspace(1,4,4); + auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); + + sd::ops::diag op; + auto result = op.evaluate({input}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete input; +} + +TEST_F(DeclarableOpsTests3, diag_test_col_vector) { + + + auto input = NDArrayFactory::linspace(1,4,4); + input->reshapei({4,1}); + auto expected= NDArrayFactory::create('c', {4,4}, {1,0,0,0, 0,2,0,0, 0,0,3,0,0,0,0,4}); + + sd::ops::diag op; + auto result = op.evaluate({input}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + + delete input; +} +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, diag_test3) { + + auto input= NDArrayFactory::create('c', {1, 3}); + input.linspace(1); + + auto expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); + + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, diag_test4) { + + auto input= NDArrayFactory::create('c', {3, 1}); + input.linspace(1); + + auto 
expected= NDArrayFactory::create('c', {3,3}, {1,0,0, 0,2,0, 0,0,3}); + + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, diag_test5) { + + auto input= NDArrayFactory::create('c', {1, 1}); + input.linspace(2); + + auto expected= NDArrayFactory::create('c', {1,1}, {2}); + + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, diag_test6) { + + auto input= NDArrayFactory::create('c', {2,2,2}); + input.linspace(1); + + auto expected= NDArrayFactory::create('c', {2,2,2,2,2,2}, {1,0,0,0, 0,0,0,0, 0,2,0,0, 0,0,0,0, 0,0,3,0, 0,0,0,0, 0,0,0,4, 0,0,0,0, 0,0,0,0, 5,0,0,0, 0,0,0,0, 0,6,0,0, 0,0,0,0, 0,0,7,0, 0,0,0,0, 0,0,0,8}); + + sd::ops::diag op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, matrixSetDiag_test1) { + + auto input= NDArrayFactory::create('c', {4,3,2}); + auto diagonal= NDArrayFactory::create('c', {4,2}); + input.assign(0.); + diagonal.assign(1.); + + auto expected= NDArrayFactory::create('c', {4,3,2}, {1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0, 1,0,0,1,0,0}); + + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = 
result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) { + + auto input= NDArrayFactory::create('c', {1,1,2}); + auto diagonal= NDArrayFactory::create('c', {1,1}); + input.assign(0.); + diagonal.assign(1.); + + auto expected= NDArrayFactory::create('c', {1,1,2}, {1.f, 0.f}); + + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, matrixSetDiag_test3) { + + auto input= NDArrayFactory::create('c', {2,1,4}); + auto diagonal= NDArrayFactory::create('c', {2,1}); + input.assign(0.); + diagonal.assign(1.); + + auto expected= NDArrayFactory::create('c', {2,1,4}, {1,0,0,0,1,0,0,0}); + + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, matrixSetDiag_test4) { + + auto input= NDArrayFactory::create('c', {2,1,4,1}); + auto diagonal= NDArrayFactory::create('c', {2,1,1}); + input.assign(0.); + diagonal.assign(1.); + + auto expected= NDArrayFactory::create('c', {2,1,4,1}, {1,0,0,0,1,0,0,0}); + + sd::ops::matrix_set_diag op; + auto result = op.evaluate({&input, &diagonal}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, diagPart_test1) { + + auto input= NDArrayFactory::create('c', {2,2}); + input.linspace(1); + + auto expected= NDArrayFactory::create('c', {2}, {1,4}); + + sd::ops::diag_part op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, diagPart_test2) { + + auto input= NDArrayFactory::create('c', {2,2,2,2}); + input.linspace(1); + + auto expected= NDArrayFactory::create('c', {2,2}, {1,6,11,16}); + + sd::ops::diag_part op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, diagPart_test3) { + + auto input= NDArrayFactory::create('c', {2,2,2,2,2,2}); + input.linspace(1); + + auto expected= NDArrayFactory::create('c', {2,2,2}, {1,10,19,28,37,46,55,64}); + + sd::ops::diag_part op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test1) { + + auto a = NDArrayFactory::create('c', {3,3}); + auto b = NDArrayFactory::create('c', {3,3}); + auto x = NDArrayFactory::create('c', {3,3}); + + a.linspace((float16)0.1, (float16)0.1); + b.linspace((float16)0.1, (float16)0.1); + x.assign(0.1); + + auto expected = NDArrayFactory::create('c', {3,3}, 
{0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-2)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test2) { + + auto a= NDArrayFactory::create('c', {3,3}); + auto b= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + a.linspace(0.1, 0.1); + b.linspace(0.1, 0.1); + x.assign(0.1); + + auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test3) { + + auto a= NDArrayFactory::create('c', {3,3}); + auto b= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + a.linspace(0.1, 0.1); + b.linspace(0.1, 0.1); + x.assign(0.1); + + auto expected= NDArrayFactory::create('c', {3,3}, {0.40638509f, 0.33668978f, 0.28271242f, 0.23973916f, 0.20483276f, 0.17604725f, 0.15203027f, 0.13180567f, 0.114647f}); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test4) { + + auto a= 
NDArrayFactory::create('c', {3,3}); + auto b= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + a.linspace(1); + b.linspace(1); + x.assign(0.1); + + auto expected= NDArrayFactory::create('c', {3,3}, {1.00000000e-01f, 2.80000000e-02f, 8.56000000e-03f, 2.72800000e-03f, 8.90920000e-04f, 2.95706080e-04f, 9.92854864e-05f, 3.36248880e-05f, 1.14644360e-05f}); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test5) { + + auto a= NDArrayFactory::create('c', {3,3}); + auto b= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + a.linspace(3200.); + b.linspace(3200.); + x.assign(0.1); + + auto expected= NDArrayFactory::create('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test6) { + + auto a= NDArrayFactory::create('c', {3,3}); + auto b= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + a.linspace(10.); + b.linspace(10.); + x.assign(0.1); + + auto expected= NDArrayFactory::create('c', {3,3}, {3.92988233e-06f, 1.35306497e-06f, 4.67576826e-07f, 1.62083416e-07f, 5.63356971e-08f, 1.96261318e-08f, 6.85120307e-09f, 2.39594668e-09f, 8.39227685e-10f}); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = 
result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test7) { + + auto a= NDArrayFactory::create('c', {3,3}); + auto b= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + a.linspace(10.); + b.linspace(10.); + x.assign(0.9); + + auto expected= NDArrayFactory::create('c', {3,3}, {0.99999607f, 0.99999865f, 0.99999953f, 0.99999984f, 0.99999994f, 0.99999998f, 0.99999999f, 1.f, 1.f}); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test8) { + + auto a= NDArrayFactory::create('c', {3,3}); + auto b= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + a.linspace(10.); + b.linspace(10.); + x.assign(1.); + + auto expected= NDArrayFactory::create('c', {3,3}, {1.f, 1.f, 1.f,1.f,1.f,1.f,1.f,1.f,1.f}); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output, 1e-6)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test9) { + + auto a= NDArrayFactory::create('c', {3,3}); + auto b= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + a.linspace(10.); + b.linspace(10.); + x.assign(0.); + + auto expected= NDArrayFactory::create('c', {3,3}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, 
{}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test10) { + + auto a= NDArrayFactory::create('c', {3,3}); + auto b= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + a.linspace(10.); + b.linspace(10.); + x.assign(0.5); + + auto expected= NDArrayFactory::create('c', {3,3}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f}); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test11) { + + NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, sd::DataType::FLOAT32); + NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, sd::DataType::FLOAT32); + NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, sd::DataType::FLOAT32); + + NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, sd::DataType::FLOAT32); + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test12) { + + NDArray a('c', {4}, {8.0091f, 8.2108f, 7.5194f, 3.0780f}, sd::DataType::FLOAT32); + NDArray b('c', {4}, {7.9456f, 9.3527f, 9.8610f, 5.3541f}, sd::DataType::FLOAT32); + NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, sd::DataType::FLOAT32); + + NDArray expected('c', {4}, {0.9999995 , 
0.8594694 , 0.999988 , 0.49124345}, sd::DataType::FLOAT32); + + sd::ops::betainc op; + auto result = op.evaluate({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, zeta_test1) { + + auto x= NDArrayFactory::create('c', {3,3}); + auto q= NDArrayFactory::create('c', {3,3}); + + q.linspace(1.); + x.assign(2.); + + auto expected= NDArrayFactory::create('c', {3,3}, {1.64493407f, 0.64493407f, 0.39493407f, 0.28382296f, 0.22132296f, 0.18132296f, 0.15354518f, 0.13313701f, 0.11751201f}); + + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, zeta_test2) { + + auto x= NDArrayFactory::create('c', {3,3}); + auto q= NDArrayFactory::create('c', {3,3}); + + q.linspace(10.); + x.assign(2.); + + auto expected= NDArrayFactory::create('c', {3,3}, {0.10516634f, 0.09516634f, 0.08690187f, 0.07995743f, 0.07404027f, 0.06893823f, 0.06449378f, 0.06058753f, 0.05712733f}); + + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, zeta_test3) { + + auto x= NDArrayFactory::create('c', {3,3}); + auto q= NDArrayFactory::create('c', {3,3}); + + q.linspace(100.); + x.assign(2.); + + auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 
0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); + + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, zeta_test4) { + + auto x= NDArrayFactory::create('c', {3,3}); + auto q= NDArrayFactory::create('c', {3,3}); + + q.linspace(100.); + x.assign(2.); + + auto expected= NDArrayFactory::create('c', {3,3}, {0.01005017f, 0.00995017f, 0.00985214f, 0.00975602f, 0.00966176f, 0.0095693f, 0.0094786f, 0.0093896f, 0.00930226f}); + + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, zeta_test5) { + + auto x= NDArrayFactory::create('c', {3,3}); + auto q= NDArrayFactory::create('c', {3,3}); + + q.linspace(1.); + x.assign(1.1); + + auto expected= NDArrayFactory::create('c', {3,3}, {10.58444846f, 9.58444846f, 9.11793197f, 8.81927915f, 8.60164151f, 8.43137352f, 8.29204706f, 8.17445116f, 8.07291961f}); + + + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, zeta_test6) { + + auto x= NDArrayFactory::create('c', {3,3}); + auto q= NDArrayFactory::create('c', {3,3}); + + q.linspace(1.); + x.assign(1.01); + + auto expected= NDArrayFactory::create('c', {3,3}, {100.57794334f, 99.57794334f, 99.08139709f, 98.75170576f, 
98.50514758f, 98.30834069f, 98.1446337f, 98.00452955f, 97.88210202f}); + + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, zeta_test7) { + + auto x= NDArrayFactory::create('c', {3,3}); + auto q= NDArrayFactory::create('c', {3,3}); + + q.linspace(1.); + x.assign(10.); + + auto expected= NDArrayFactory::create('c', {3,3}, {1.00099458e+00f, 9.94575128e-04f, 1.80126278e-05f, 1.07754001e-06f, 1.23865693e-07f, 2.14656932e-08f, 4.92752156e-09f, 1.38738839e-09f, 4.56065812e-10f}); + + + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, zeta_test8) { + + auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); + auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + + //q.linspace(1.); + //x.assign(10.); + + auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + + sd::ops::zeta op; + auto result = op.evaluate({&x, &q}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, zeta_test9) { + + auto x= NDArrayFactory::create('c', {3,4}, 
{1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); + auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z= NDArrayFactory::create('c', {3,4}, {1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.,1.}); + + //q.linspace(1.); + //x.assign(10.); + + auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + + sd::ops::zeta op; + auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results); + + //auto *output = result.at(0); + // z.printIndexedBuffer("Zeta output"); + ASSERT_TRUE(expected.isSameShape(z)); + ASSERT_TRUE(expected.equalsTo(z)); + +// +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, zeta_test10) { + + auto x= NDArrayFactory::create('c', {3,4}, {1.1,1.2,1.3,1.4,1.5,1.6,1.7,1.8,1.9,1.01,1.11,1.12}); + auto q= NDArrayFactory::create('c', {3,4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z= NDArrayFactory::create('c', {3,4}); + + //q.linspace(1.); + //x.assign(10.); + + auto expected= NDArrayFactory::create('c', {3,4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + + sd::ops::zeta op; + auto results = op.execute({&x, &q}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results); + + //auto *output = result.at(0); + // z.printIndexedBuffer("Zeta output"); + ASSERT_TRUE(expected.isSameShape(z)); + ASSERT_TRUE(expected.equalsTo(z)); + +// +} + + +TEST_F(DeclarableOpsTests3, Test_SplitV_Validation_1) { + auto x = NDArrayFactory::create('c', {8, 7}); + auto indices = NDArrayFactory::create('c',{2}, {5, 3}); + auto axis = NDArrayFactory::create(-2); + + auto z0 = NDArrayFactory::create('c', {5, 7}); + auto z1 = NDArrayFactory::create('c', {3, 7}); + + sd::ops::split_v op; + auto status = 
op.execute({&x, &indices, &axis}, std::vector{&z0, &z1}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, polygamma_test1) { + + auto n= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); +// ASSERT_FALSE(true); + n.linspace(1.); + x.assign(0.5); + + auto expected= NDArrayFactory::create('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08}); + + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + // output->printBuffer(); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, polygamma_test2) { + + auto n= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + n.linspace(10.); + x.linspace(0.5); + + auto expected= NDArrayFactory::create('c', {3,3}, {-7.43182451e+09, 3.08334759e+05,-3.25669798e+03, 1.55186197e+02,-1.46220433e+01, 2.00905201e+00,-3.48791235e-01, 7.08016273e-02,-1.60476052e-02}); + + //ASSERT_FALSE(true); + + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, polygamma_test3) { + + auto n= NDArrayFactory::create('c', {3,3}); + auto x= NDArrayFactory::create('c', {3,3}); + + n.linspace(1.); + x.linspace(10.); + + auto expected= NDArrayFactory::create('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 
2.58758130e-07}); + + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +TEST_F(DeclarableOpsTests3, polygamma_test4) { + + NDArray n('c', {3,4}, {/*0.7788*/0, 0,1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); + NDArray x('c', {3,4}, {0.7717,0.9281,0.9846,0.4838,0.6433,0.6041,0.6501,0.7612,0.7605,0.3948,0.9493,0.8600}, sd::DataType::DOUBLE); + + NDArray expected('c', {3,4}, {/*std::numeric_limits::quiet_NaN()*/-1.031918, -7.021327e-01, 1.682743e+00, -1.851378e+01,3.604167e+01, -3.008293e+02, + 1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, sd::DataType::DOUBLE); + + sd::ops::polygamma op; + auto result = op.evaluate({&n, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +TEST_F(DeclarableOpsTests3, digamma_1) { + + NDArray x('c', {18}, {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, sd::DataType::DOUBLE); + + NDArray expected('c', {18}, {std::numeric_limits::infinity(), -99996.761229, 3.091129, 7.401432, 1.792911,11.196838,10.630354, 0.03649, 2.11331, + std::numeric_limits::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, sd::DataType::DOUBLE); + + sd::ops::digamma op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test1) { + + auto x= NDArrayFactory::create('c', {6,6}, {0. ,-9. 
,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16}); + auto expS= NDArrayFactory::create('c', {6}, {54.12775, 38.79293, 25.89287, 9.82168, 6.07227, 2.91827}); + auto expU= NDArrayFactory::create('c', {6,6}, {0.14692,-0.11132,-0.69568, 0.59282,-0.14881, 0.32935,-0.38751, 0.60378,-0.04927,-0.01397,-0.69456,-0.01581, 0.19293,-0.12795,-0.18682,-0.69065,-0.20597, 0.62617, 0.66806, 0.4314 ,-0.33849,-0.22166, 0.04099,-0.44967, 0.11121,-0.64065,-0.02138,-0.07378,-0.60568,-0.45216,-0.5765 ,-0.1007 ,-0.60305,-0.34175, 0.29068,-0.3042}); + auto expV= NDArrayFactory::create('c', {6,6}, {-0.24577,-0.24512, 0.00401,-0.04585,-0.62058, 0.70162, 0.27937, 0.75961, 0.43885,-0.06857,-0.3839 , 0.01669,-0.35944,-0.09629, 0.44593, 0.78602,-0.09103,-0.19125, 0.53973, 0.07613,-0.10721, 0.49559, 0.35687, 0.56431,-0.6226 , 0.39742, 0.12785,-0.15716, 0.52372, 0.37297, 0.23113,-0.43578, 0.76204,-0.32414, 0.23996, 0.11543}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *s = result.at(0); + auto *u = result.at(1); + auto *v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if(sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + } + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test2) { + + auto x = NDArrayFactory::create('c', {7,6}, {0. ,-9. 
,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); + auto expS= NDArrayFactory::create('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); + auto expU= NDArrayFactory::create('c', {7,7}, {-0.13417,-0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.41683, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 , -0.12183,-0.17329,-0.14666, -0.19639, -0.55355, 0.0614 , 0.75729, 0.1619 ,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656, -0.26134,-0.08027,-0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, -0.44712, 0.55906,-0.06235, -0.58017, -0.12911, -0.359 , -0.00393, -0.44877, 0.30645,-0.11953, -0.09083, -0.54163, 0.14283, -0.50417, 0.56178}); + auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *s = result.at(0); + auto *u = result.at(1); + auto *v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if(sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + } + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, 
svd_test3) { + + auto x= NDArrayFactory::create('c', {7,6}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); + auto expS= NDArrayFactory::create('c', {6}, {56.76573, 39.11776, 26.00713, 11.83606, 6.16578, 3.99672}); + auto expU= NDArrayFactory::create('c', {7,6}, {-0.13417, -0.12443, -0.68854, 0.5196 , 0.21706, 0.03974, 0.347 , 0.62666, -0.04964, -0.01912, 0.66932, 0.1457 ,-0.17329, -0.14666, -0.19639, -0.55355, 0.0614 , 0.75729,-0.64703, 0.37056, -0.37398, -0.32922, -0.0186 , -0.35656,-0.08027, -0.64405, -0.0127 , -0.06934, 0.59287, -0.14956, 0.55906, -0.06235, -0.58017, -0.12911, -0.359 , -0.00393, 0.30645, -0.11953, -0.09083, -0.54163, 0.14283, -0.50417}); + auto expV= NDArrayFactory::create('c', {6,6}, {0.2508 ,-0.2265 , 0.01689, 0.04486, 0.53132, 0.77537,-0.32281, 0.74559, 0.41845, -0.13821, 0.37642, 0.06315, 0.33139,-0.05528, 0.47186, 0.73171, 0.18905, -0.3055 ,-0.57263, 0.06276,-0.09542, 0.59396, -0.36152, 0.419 , 0.59193, 0.4361 , 0.13557, -0.03632, -0.5755 , 0.32944,-0.21165,-0.44227, 0.75794, -0.29895, -0.27993, 0.13187}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *s = result.at(0); + auto *u = result.at(1); + auto *v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if(sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); + } + +} + +/////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests3, svd_test4) { + + auto x= NDArrayFactory::create('c', {6,7}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); + auto expS= NDArrayFactory::create('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); + auto expU= NDArrayFactory::create('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); + auto expV= NDArrayFactory::create('c', {7,7}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, -0.16709, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, -0.06862, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979, 0.84807,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.36692, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651, -0.27155,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.15069, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151 , 0.13065}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *s = result.at(0); + auto *u = result.at(1); + auto *v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if(sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); + } + 
+} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test5) { + + auto x= NDArrayFactory::create('c', {6,7}, {0. ,-9. ,-6 ,9 ,-10 ,-12 ,2 ,13 ,5 ,-11 ,20 ,-17 ,1 ,-2 ,-11 ,3 ,-8 ,3 ,-14 ,19 ,-20 ,20 ,-17 ,-5 ,6 ,-16 ,0 ,-1 ,-16 ,11 ,7 ,-19 ,2 ,-17 ,17 ,-16, 4, -9, 1, -15, 7, -2}); + auto expS= NDArrayFactory::create('c', {6}, {53.11053, 39.09542, 28.1987, 17.7468, 11.61684, 5.36217}); + auto expU= NDArrayFactory::create('c', {6,6}, {-0.16541, 0.21276, 0.51284, 0.20472, 0.74797, 0.25102,-0.49879, 0.12076, 0.37629, -0.7211 , -0.24585, 0.12086,-0.36569,-0.70218, -0.08012, 0.21274, -0.07314, 0.56231,-0.44508, 0.4329 , 0.1356 , 0.60909, -0.47398, -0.02164, 0.61238,-0.05674, 0.59489, 0.06588, -0.3874 , 0.33685,-0.13044,-0.50644, 0.46552, 0.13236, -0.00474, -0.70161}); + auto expV= NDArrayFactory::create('c', {7,6}, {-0.35914, 0.68966, -0.30077, -0.15238, -0.48179, 0.14716, 0.21989, -0.34343, 0.11086, -0.78381, -0.37902, 0.24224, 0.32179, 0.12812, -0.25812, 0.0691 , -0.12891, 0.26979,-0.50833, 0.13793, 0.06658, -0.53001, 0.52572, -0.16194, 0.48118, 0.15876, -0.65132, -0.24602, 0.3963 , -0.16651,-0.31605, -0.46947, -0.50195, 0.0378 , -0.34937, -0.53062, 0.35957, 0.35408, 0.38732, -0.12154, -0.22827, -0.7151}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *s = result.at(0); + auto *u = result.at(1); + auto *v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if(sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 
1e-5f); + } + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test6) { + + auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. ,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2 + ,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17 + ,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14 + ,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16 ,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); + auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031, + 38.18412, 31.52287, 23.52755, 11.79484, 1.90195, + 39.34498, 32.54861, 17.52492, 7.03003, 2.2399, + 44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); + auto expU= NDArrayFactory::create('c', {2,2,5,5}, {0.25441, 0.16908, -0.68564, 0.58844, -0.30054, + -0.32285, -0.58332, 0.3451 , 0.4746 , -0.45953,0.58332, 0.10605, 0.51533, 0.50234, 0.36136,0.12588, -0.73123, -0.37812, -0.00215, 0.55361, + 0.68915, -0.2919 , 0.04767, -0.4197 , -0.51132,0.44464, -0.25326, -0.42493, -0.01712, -0.74653,0.516 , -0.16688, 0.1854 , -0.77155, 0.27611, + -0.19321, -0.14317, -0.85886, -0.15224, 0.42585,-0.60155, -0.68323, 0.18819, -0.29053, -0.22696,-0.36993, 0.64862, -0.10956, -0.54483, -0.36552, + -0.57697, -0.32277, 0.11229, 0.55495, 0.4923 ,-0.02937, 0.01689, -0.63257, 0.57075, -0.52245,-0.56002, -0.2036 , -0.53119, -0.6022 , 0.01017, + -0.33605, -0.35257, 0.53215, -0.04936, -0.69075,0.48958, -0.85427, -0.14796, -0.03449, 0.08633,0.15008, 0.60996, 0.31071, -0.67721, 0.22421, + 0.67717, -0.59857, 0.04372, -0.2565 , 0.33979,0.68116, 0.49852, -0.13441, 0.51374, -0.07421,-0.20066, 0.04504, 0.42865, 0.44418, 0.75939,0.12113, -0.13826, 0.83651, 0.11988, -0.50209}); + auto expV= NDArrayFactory::create('c', {2,2,5,5}, {0.01858, 0.17863, 0.51259, 0.14048, 0.82781, + 0.59651, -0.13439, -0.395 , 0.66979, 
0.14654,0.73731, 0.47061, 0.19357, -0.41127, -0.16817,0.1047 , -0.29727, 0.73711, 0.38235, -0.45951, + -0.29873, 0.80012, -0.02078, 0.4651 , -0.23201,-0.05314, -0.0419 , -0.52146, 0.77792, 0.344 ,-0.66438, 0.05648, 0.03756, -0.31531, 0.67422, + 0.74471, 0.01504, -0.03081, -0.24335, 0.62049,0.03172, 0.91947, 0.30828, 0.23713, 0.04796,-0.01311, 0.38652, -0.79415, -0.42423, -0.19945, + -0.13783, -0.54667, -0.58527, 0.49955, 0.3001 ,0.85214, 0.01628, 0.02688, -0.02891, 0.52157,0.16608, -0.20181, 0.61371, 0.69894, -0.25794, + 0.45726, -0.33952, -0.32659, -0.18938, -0.73015,0.13486, 0.73816, -0.41646, 0.47458, -0.1956 ,0.5536 , -0.137 , 0.64688, 0.50536, 0.03017, + -0.51827, -0.31837, -0.16732, 0.71378, -0.30425,-0.39314, 0.15266, 0.63693, -0.30945, -0.5663 ,-0.51981, 0.03325, 0.37603, 0.05147, 0.76462,-0.01282, 0.92491, -0.08042, 0.36977, -0.03428}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *s = result.at(0); + auto *u = result.at(1); + auto *v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if(sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5f); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5f); + } + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test7) { + + auto x= NDArrayFactory::create('c', {2,2,5,5}, {-7. 
,17 ,4 ,-10 ,5 ,1 ,-5 ,-19 ,13 ,-8 ,9 ,13 ,19 ,13 ,-2,-8 ,10 ,-9 ,0 ,-20 ,-2 ,14 ,19 ,5 ,-18 ,4 ,-13 ,12 ,-10 + ,5 ,-10 ,-10 ,17 ,-5 ,-2 ,10 ,5 ,-4 ,-11 ,15 ,-3 ,15 ,-17,-20 ,-10 ,-4 ,12 ,-9 ,16 ,13 ,10 ,-19 ,2 ,-9 ,-10 ,8 ,-2 + ,-4 ,3 ,7 ,10 ,-19 ,-11 ,-4 ,-6 ,2 ,-12 ,6 ,-4 ,-14 ,14,16 ,7 ,19 ,-17 ,2 ,-14 ,5 ,-1 ,16 ,19 ,-11 ,-14 ,-16,-19 ,15 ,-18 ,-12 ,-16 ,16 ,1 ,5 ,7 ,8 ,2 ,13 ,-3 ,6 ,2 ,-5}); + auto expS= NDArrayFactory::create('c', {2,2,5}, {40.95395, 31.46869, 24.79993, 12.33768, 1.80031,38.18412, 31.52287, 23.52755, 11.79484, 1.90195, + 39.34498, 32.54861, 17.52492, 7.03003, 2.2399,44.72126, 32.3164 , 16.60139, 6.88783, 0.78122}); + + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 0, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *s = result.at(0); + + ASSERT_TRUE(expS.equalsTo(s)); + ASSERT_TRUE(expS.isSameShape(s)); + +} + +/////////////////////////////////////////////////////////////////// +// TEST_F(DeclarableOpsTests3, svd_test8) { + +// auto x= NDArrayFactory::create('c', {2,2,11,10}, {3 ,-8 ,0 ,3 ,-5 ,16 ,-3 ,7 ,-4 ,19 ,19 ,13 ,15 ,15 ,9 ,6 ,-7 ,-5 ,-9 ,-12 ,7 ,-1 ,-1 ,6 ,19 +// ,-6 ,16 ,0 ,16 ,16 ,7 ,14 ,18. ,0 ,18 ,-4 ,10 ,-16 ,-17 ,15 ,13 ,-17 ,-14 ,-17 ,-5 ,-9 ,-1 ,-19 +// ,-18 ,5 ,-5 ,-13 ,17 ,-19 ,-5 ,18 ,4 ,10 ,17 ,-7 ,-10 ,16 ,10 ,8 ,-10 ,-3 ,10 ,1 ,-4 ,-16 ,-1 +// ,-1 ,5 ,5 ,17 ,14 ,20 ,15 ,-6 ,19 ,14 ,17 ,0 ,-17 ,-16 ,-8 ,-6 ,3 ,-6 ,-11 ,-4 ,-2 ,-7 ,4 ,-6 +// ,-6 ,-17 ,16 ,-8 ,-20 ,2 ,7 ,-12 ,15 ,-15 ,-19 ,14 ,17 ,9 ,10 ,5 ,18 ,2 ,-6 ,0 ,2 ,-10 ,7 ,8 +// ,-13 ,2 ,8 ,20 ,11 ,-15 ,13 ,-10 ,-14 ,-2 ,20 ,5 ,2 ,16 ,18 ,-3 ,3 ,-18 ,15 ,-11 ,17 ,-8 ,-18 +// ,20 ,-12 ,20 ,20 ,-16 ,20 ,-8. 
,19 ,-8 ,3 ,-3 ,17 ,7 ,13 ,9 ,-2 ,11 ,16 ,4 ,-18 ,5 ,0 ,-12 ,9 +// ,-6 ,6 ,0 ,-9 ,-13 ,13 ,17 ,-12 ,3 ,-13 ,17 ,-19 ,17 ,0 ,-8 ,4 ,-19 ,-9 ,-7 ,12 ,-1 ,-12 ,-1 +// ,7 ,2 ,19 ,10 ,19 ,-15 ,-18 ,17 ,-1 ,1 ,14 ,-7 ,-10 ,12 ,-20 ,6 ,-5 ,14 ,5 ,5 ,3 ,-18 ,5 ,17 +// ,-13 ,20 ,-1 ,-2 ,-11 ,-5 ,14 ,8 ,7 ,-13 ,-9 ,-12 ,11 ,3 ,14 ,-6 ,-2 ,13 ,8 ,-15 ,-5 ,-6 ,-7 ,19 +// ,-1 ,6 ,1 ,14 ,8 ,18 ,-20 ,-14 ,-3 ,-5 ,19 ,15 ,13 ,2 ,-20 ,2 ,14 ,13 ,4 ,-15 ,1 ,-14 +// ,0 ,9 ,-1 ,10 ,4 ,6 ,4 ,-7 ,-2 ,-1 ,-15 ,-1 ,-16 ,-5 ,-12 ,-10 ,16 ,-16 ,-15 ,-17 ,-5 ,-6 +// ,18 ,14 ,-3 ,-10 ,8 ,20 ,19 ,20 ,-3 ,-6 ,9 ,10 ,-1 ,-20 ,-5 ,5 ,12 ,8 ,17 ,13 ,-18 ,-14 ,0 +// ,4 ,-11 ,3 ,-12 ,-2 ,-5 ,19 ,-15 ,19 ,16 ,-16 ,13 ,-6 ,11 ,11 ,0 ,-18 ,4 ,5 ,6 ,-12 ,-10 +// ,-3 ,2 ,-18 ,16 ,-5 ,17 ,16 ,-16 ,-20 ,14 ,6 ,10 ,-5 ,-3 ,4 ,20 ,18 ,5 ,1 ,-10 ,15 ,10 ,16 +// ,-18 ,2 ,12 ,20 ,6 ,14 ,8 ,3 ,-2 ,9 ,15 ,-4 ,13 ,-19 ,-5 ,3 ,3 ,-20 ,-4 ,18 ,-11 ,11 ,-10 +// ,3 ,8 ,9 ,20 ,-19 ,6 ,18 ,9 ,20 ,-12 ,4 ,15 ,19 ,3 ,5 ,1 ,2 ,20 ,-3 ,-1 ,-8 ,-3 ,8 ,17 , +// -14 ,18 ,-10 ,4 ,13 ,-5 ,13 ,-6 ,12 ,-10 ,19 ,4 ,-7 ,-17 ,20 ,8 ,6 ,-3 ,3 ,-7 ,-18 ,17 , +// -13 ,18 ,-20 ,-16 ,-5 ,12 ,5 ,17 ,-4 ,4 ,7 ,8 ,17 ,-9 ,-12 ,-10 ,8 ,-14 ,-11 ,7 ,19 ,-17}); + +// auto expS= NDArrayFactory::create('c', {2,2,10}, { 64.12636, 54.37044, 50.63744, 48.10308, 33.7364 , 29.96456, +// 25.53945, 19.31856, 15.30939, 9.31349, +// 67.41342, 59.64963, 58.72687, 39.22496, 32.39772, 29.30833, +// 23.1491 , 16.92442, 6.38613, 3.49563, +// 74.37477, 52.07016, 46.10758, 39.10742, 32.02261, 27.05888, +// 20.54921, 13.17989, 8.4158 , 4.39974, +// 65.47447, 56.31305, 54.13371, 46.26955, 43.47755, 30.25799, +// 20.71463, 16.89671, 10.39572, 7.81631}); + +// auto expU= NDArrayFactory::create('c', {2,2,11,11}, {-0.177870, -0.149461, -0.196911, 0.036990, -0.338237, 0.548901, +// -0.074396, 0.497067, -0.083636, -0.111810, -0.466989, -0.010465, 0.434732, 0.337198, 0.305239, -0.292813, +// 0.041280, -0.517144, 0.121499, 0.464908, 0.003658, 0.135017, -0.446916, 
-0.098318, 0.073571, -0.200521, +// 0.186776, -0.353022, -0.435582, -0.225959, 0.052972, 0.032390, -0.583801, -0.402790, 0.562809, 0.102744, +// 0.066555, 0.206079, 0.115322, 0.217220, -0.062591, -0.273173, -0.569645, 0.005612, 0.092601, 0.350055, +// -0.608007, -0.367743, 0.064860, 0.112656, 0.091576, -0.144262, 0.554655, -0.042100, -0.092023, 0.026986, +// -0.395811, -0.245209, 0.572522, 0.429430, 0.099621, -0.159236, -0.086263, 0.268160, -0.391298, 0.050417, +// 0.150175, 0.045253, 0.464173, 0.138376, 0.265551, 0.049691, 0.528778, 0.116951, 0.384609, 0.144416, +// -0.453591, -0.519390, -0.150671, 0.072897, 0.102406, -0.154184, 0.450735, 0.174171, -0.519405, 0.147109, +// 0.333670, 0.178053, 0.360763, 0.226976, 0.069976, -0.046765, 0.448897, 0.511309, -0.361050, -0.191690, +// -0.304442, 0.270383, -0.124133, 0.417183, -0.083359, 0.137022, 0.004276, -0.462336, 0.051267, 0.020622, +// -0.566932, -0.051351, -0.417106, -0.292202, -0.021595, -0.315956, 0.396626, -0.604952, 0.155990, 0.258395, +// -0.125080, 0.115404, 0.234517, -0.357460, 0.271271, 0.063771, -0.087400, -0.024710, -0.179892, 0.584339, +// -0.413085, 0.510580, 0.334646, 0.044424, 0.224735, 0.134434, -0.147861, 0.291853, 0.487948, 0.238917, +// 0.433893, 0.435884, 0.056370, -0.051216, -0.450902, 0.062411, 0.080733, -0.365211, 0.031931, 0.493926, +// -0.239428, 0.038247, -0.180721, -0.118035, 0.042175, 0.377296, -0.516399, 0.324744, -0.756196, 0.160856, +// -0.152527, -0.046867, -0.092933, -0.044945, 0.137659, 0.246552, -0.071709, 0.032821, -0.529356, -0.029669, +// 0.200178, 0.188916, 0.428036, -0.496734, -0.164185, 0.629070, -0.131588, 0.073992, 0.066877, 0.208450, +// -0.156170, -0.253670, -0.000365, -0.121172, 0.067774, 0.618226, 0.230460, -0.118865, 0.579424, 0.324523, +// 0.038653, 0.310308, 0.570186, -0.217271, -0.110967, 0.196375, 0.167058, 0.264071, -0.130023, 0.254189, +// -0.459057, -0.301033, 0.069932, -0.033338, -0.070600, 0.685064, 0.130274, 0.074929, -0.206899, 0.574057, +// 0.327277, 
-0.131588, -0.018497, 0.312445, 0.314594, 0.480422, -0.293858, -0.273277, -0.006598, -0.134574, +// 0.403501, 0.140025, 0.380693, -0.257039, -0.067012, 0.248776, -0.361838, -0.270296, -0.225844, 0.320245, +// 0.055730, 0.454809, -0.212163, -0.063281, 0.563112, -0.200737, 0.537389, -0.210845, 0.109997, 0.166215, +// -0.243725, -0.347349, -0.274348, 0.263950, 0.437134, 0.265820, -0.127520, -0.033325, -0.137156, 0.518557, +// 0.246720, 0.389394, -0.600568, 0.062027, -0.047838, -0.338416, 0.032778, -0.141998, -0.338022, -0.381467, +// 0.210512, -0.314413, 0.256321, 0.001460, 0.238901, 0.139840, 0.633423, -0.182575, -0.461504, 0.290250, +// -0.025930, 0.336998, -0.211280, -0.662387, -0.207946, -0.003860, -0.147842, 0.157217, 0.123704, 0.345686, +// 0.337946, 0.138261, -0.178814, -0.109597, 0.087135, -0.509500, -0.300296, -0.262279, 0.377476, -0.366815, +// 0.091787, 0.247495, -0.193812, -0.179714, 0.238552, -0.162305, -0.029549, 0.785426, -0.157586, -0.084533, +// -0.357024, 0.317878, 0.217656, 0.125319, 0.648832, 0.344045, -0.001109, 0.457190, -0.072439, -0.106278, +// 0.228962, -0.136139, -0.528342, -0.020840, -0.108908, -0.231661, 0.396864, 0.234925, 0.180894, -0.179430, +// -0.587730, 0.178276, -0.008672, -0.386172, 0.033155, 0.319568, 0.101457, -0.272011, 0.126007, 0.175374, +// -0.081668, 0.112987, -0.296422, -0.713743, 0.269413, -0.082098, -0.338649, 0.131035, -0.518616, 0.022478, +// 0.177802, -0.042432, -0.606219, -0.343848, 0.014416, -0.141375, 0.748332, -0.165911, -0.049067, -0.241062, +// 0.436318, 0.173318, 0.058066, 0.193764, -0.000647, 0.265777, -0.027847, -0.096305, 0.711632, 0.066506, +// -0.223124, 0.219165, -0.038165, 0.427444, -0.296887, 0.139982, 0.298976, 0.294876, -0.001315, 0.419802, +// 0.475401, -0.156256, -0.289477, -0.438761, -0.116348, 0.108350, -0.369368, -0.219943, 0.433088, 0.187565, +// -0.217259, 0.147014, -0.538991, -0.065052, 0.310337, 0.491887, 0.254439, 0.075052, 0.071155, -0.084856, +// 0.402098, 0.096270, 0.093662, -0.475769, 
0.256832, 0.161394, -0.390050, -0.513551, -0.184665, 0.211506, +// -0.112525, -0.493409, -0.258765, 0.262124, -0.272998, 0.269370, 0.266226, -0.367919, 0.192386, -0.006422, +// -0.466728, -0.481792, 0.090611, -0.156359, 0.178693, -0.371658, -0.214190, -0.469058, -0.006134, 0.081902, +// 0.536950, 0.064836, -0.334010, 0.523530, -0.182061, -0.206686, 0.002985, 0.054858, -0.038727, -0.075390, +// 0.543839, -0.442964, -0.190550, -0.298127, -0.065323, 0.131415, 0.329899, 0.122096, -0.507075, 0.523751, +// -0.167317, 0.198593, -0.069066, 0.402739, 0.328583, 0.314184, -0.268003, -0.148549, 0.118925, -0.508174, +// 0.128716, -0.405597, -0.157224, 0.271021, -0.384444, -0.174935, 0.343919, -0.076726, 0.607931, 0.383931, +// 0.198254, 0.133707, 0.321460, -0.232543, 0.099988, -0.321954, -0.366304, -0.137440, 0.232835, -0.290306, +// -0.260804, -0.347721, 0.182895, 0.382311, -0.332847, -0.192469, -0.438258, -0.017533, -0.192976, -0.702531, +// 0.124463, 0.039719, -0.221319, -0.224785, 0.096356, -0.302131, -0.462598, 0.194320}); + + +// auto expV= NDArrayFactory::create('c', {2,2,10,10}, {-0.050761, 0.370975, -0.061567, -0.125530, 0.024081, 0.275524, -0.800334, +// -0.025855, 0.348132, 0.036882, 0.034921, 0.307295, 0.629837, 0.014276, 0.265687, 0.188407, -0.035481, 0.082827, +// -0.490175, 0.391118, -0.180180, 0.169108, 0.206663, 0.623321, 0.260009, 0.081943, 0.004485, 0.136199, 0.060353, +// -0.641224, -0.181559, -0.041761, 0.578416, -0.161798, -0.573128, -0.187563, 0.012533, 0.368041, 0.314619, +// -0.079349, -0.527508, 0.216020, 0.004721, 0.188769, -0.242534, -0.442685, -0.121683, -0.565306, -0.202894, +// 0.095280, -0.181900, -0.170627, -0.201655, 0.620259, -0.257996, 0.277656, -0.009623, 0.266775, 0.081952, +// 0.539241, -0.452254, -0.136142, 0.177049, -0.144734, 0.494673, 0.101613, 0.280091, -0.186281, 0.548779, +// 0.235160, 0.054763, -0.571503, 0.298086, 0.035312, -0.195188, 0.474030, -0.175457, -0.497267, -0.101439, +// -0.170678, -0.060605, -0.557305, 0.073433, 
0.057195, 0.352091, -0.486102, -0.483569, 0.252091, -0.121245, +// 0.068719, -0.638919, -0.078029, -0.236556, -0.351440, -0.024437, 0.319855, -0.007406, 0.319691, -0.402334, +// -0.197966, 0.058936, -0.360900, 0.233414, -0.251532, 0.105457, 0.048097, 0.029321, 0.002714, -0.845953, +// -0.136344, 0.378037, 0.277491, 0.278420, 0.037491, 0.432117, -0.586745, 0.104573, 0.316569, -0.039848, +// 0.239645, -0.320923, 0.555156, 0.145059, -0.546959, 0.267760, 0.298029, 0.177831, -0.191286, -0.032427, +// 0.197034, 0.081887, -0.113063, 0.711713, 0.020279, -0.362346, -0.145776, 0.173289, -0.500880, 0.181624, +// 0.084391, -0.278967, 0.212143, -0.413382, 0.012879, -0.216886, -0.625774, 0.066795, -0.421937, -0.291320, +// 0.011402, -0.416660, -0.134200, 0.043039, 0.554715, 0.126867, 0.147315, 0.474334, 0.094354, -0.156458, +// 0.450168, 0.447448, 0.261750, -0.161426, -0.064309, -0.592417, 0.210891, 0.104312, 0.176178, -0.237020, +// 0.455579, -0.358056, -0.307454, 0.033700, -0.486831, -0.303963, -0.284916, 0.241549, 0.510701, 0.206104, +// 0.062587, 0.248212, 0.132088, -0.122704, 0.026342, -0.011108, 0.066306, 0.763127, 0.009491, 0.038822, +// -0.562773, -0.320104, 0.477773, 0.354169, 0.293329, -0.304227, -0.001662, -0.213324, 0.365277, -0.198056, +// -0.383499, -0.017789, 0.324542, -0.642856, 0.238689, -0.360461, -0.060599, -0.257192, 0.342400, 0.180845, +// 0.272810, -0.452278, -0.409323, 0.077013, -0.082561, 0.334893, -0.103309, -0.198049, 0.480416, 0.470593, +// 0.029072, -0.300574, 0.532293, 0.250892, -0.355298, 0.079716, -0.319781, 0.259925, 0.277872, -0.251917, +// 0.346821, 0.161642, 0.205861, 0.107125, -0.594779, -0.226272, 0.610183, -0.065926, 0.170332, 0.312553, +// -0.108093, 0.368268, -0.183109, -0.192222, -0.544559, 0.136824, -0.412352, -0.398250, -0.257291, 0.019911, +// 0.288797, 0.013350, 0.349817, -0.108331, 0.180576, 0.652863, 0.319319, 0.020218, -0.324499, 0.290877, +// 0.338518, -0.301776, -0.440871, -0.281683, -0.158759, -0.080281, 0.418260, 0.189926, 
-0.064112, -0.390914, +// 0.485420, -0.464327, 0.211070, 0.044295, -0.032292, 0.043985, 0.147160, -0.702247, -0.198395, -0.352940, +// -0.237014, -0.438235, 0.073448, -0.418712, -0.280275, -0.091373, -0.194273, 0.347558, -0.421767, 0.283011, +// -0.351869, -0.210088, -0.034628, 0.448410, 0.149194, -0.488551, -0.068805, -0.117007, -0.390999, 0.377100, +// 0.423252, -0.041944, 0.455115, -0.537818, 0.266732, 0.218202, 0.047475, -0.383506, -0.158858, 0.450881, +// 0.072415, 0.355772, 0.002360, 0.138976, 0.541349, -0.295405, 0.463832, 0.400676, -0.168962, 0.259334, +// -0.047960, 0.272197, 0.582658, 0.198052, 0.127300, -0.320468, -0.104858, -0.229698, 0.046672, -0.474224, +// 0.370765, -0.246450, 0.212667, 0.024935, -0.344530, -0.238547, 0.185931, 0.269068, 0.487414, 0.421376, +// 0.442391, -0.284247, 0.304973, -0.365006, -0.159016, -0.129088, -0.126454, 0.600462, -0.461163, -0.243552, +// -0.049814, -0.381340, -0.054504, 0.436237, 0.126120, -0.359677, -0.409734, -0.179422, -0.414820, 0.371149, +// 0.078299, 0.503544, 0.322165, 0.148341, -0.495447, -0.084355, -0.174667, 0.016802, -0.066954, 0.318825, +// -0.480771, -0.060163, 0.144302, -0.041555, 0.459106, 0.029882, -0.565026, 0.282336, 0.528472, 0.044916, +// -0.286167, -0.101052, -0.181529, -0.419406, -0.032204, -0.732282, 0.106833, -0.288881, 0.171516, -0.096242, +// -0.331834, -0.493188, 0.393195, 0.358365, 0.049125, 0.123457, 0.438169, -0.105015, 0.092386, -0.130413, -0.476991}); + +// sd::ops::svd op; +// auto results = op.execute({&x}, {}, {1, 1, 7}); + +// ASSERT_EQ(ND4J_STATUS_OK, result.status()); + +// auto *s = result.at(0); +// auto *u = result.at(1); +// auto *v = result.at(2); + + // ASSERT_TRUE(expS.isSameShape(s)); + // ASSERT_TRUE(expU.isSameShape(u)); + // ASSERT_TRUE(expV.isSameShape(v)); + + // ASSERT_TRUE(expS.equalsTo(s)); + + // if(sd::Environment::getInstance().isCPU()) { + // ASSERT_TRUE(expU.equalsTo(u)); + // ASSERT_TRUE(expV.equalsTo(v)); + // } + // else { + // for(uint i = 0; i < 
expU.lengthOf(); ++i) + // ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + // for(uint i = 0; i < expV.lengthOf(); ++i) + // ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + // } + +// +// } + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test9) { + + auto x= NDArrayFactory::create('c', {2,2,5,6}, {17 ,-11 ,20 ,-10 ,19 ,13 ,-18 ,6 ,-2 ,-6 ,-10 ,4 ,-6 ,-4 ,3 ,16 ,12 , + -15 ,8 ,-8 ,12 ,-1 ,20 ,19 ,-13 ,0 ,20 ,17 ,-8 ,16 ,-19 ,7 ,-16 ,-14 ,-5 ,7 ,7 ,-5 ,12 ,-15 ,7 ,8 , + 1 ,-8 ,-17 ,10 ,-11 ,8 ,-10 ,1 ,-6 ,10 ,15 ,19 ,-15 ,8 ,2 ,8 ,12 ,7 ,-5 ,1 ,8 ,4 ,-13 ,2 ,19 ,-2 ,-10 , + -8 ,11 ,1 ,20 ,-11 ,4 ,1 ,-17 ,-15 ,0 ,-9 ,-4 ,-1 ,-6 ,-9 ,-13 ,10 ,7 ,-2 ,15 ,-10 ,-1 ,11 ,-20 ,-2 , + -1 ,-18 ,12 ,16 ,8 ,-9 ,-20 ,-7 ,-20 ,3 ,-9 ,12 ,8 ,-19 ,-2 ,2 ,1 ,7 ,10 ,-18 ,13 ,6 ,14 ,0 ,19 ,8}); + + auto expS= NDArrayFactory::create('c', {2,2,5}, {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, + 38.56035, 30.62846, 26.31646, 19.42605, 3.01162, + 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , + 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); + + auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025,0.26329, 0.3079 , 0.38582, 0.77696, 0.28872,0.03076, 0.03015, -0.9128 , 0.36387, 0.18039, + -0.61335, 0.10076, 0.01381, 0.40922, -0.66783,-0.10577, 0.93946, -0.0871 , -0.31058, 0.04677,0.52823, 0.31163, -0.78777, 0.02322, -0.05234, + -0.23942, -0.45801, -0.34248, 0.71286, 0.32778,0.26147, 0.60409, 0.39933, 0.46862, 0.43318,0.62118, -0.37993, 0.30992, 0.34537, -0.50444, + 0.45763, -0.42877, 0.08128, -0.3904 , 0.66912,-0.05428, 0.53632, 0.19774, -0.32198, 0.75276,-0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, + -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989,-0.24459, 0.10463, -0.27652, 0.85595, 0.34657,0.50772, 0.00757, -0.82374, -0.18941, 0.16658, 0.49473, -0.39923, -0.20758, 0.74339, -0.01213, + -0.2024 , -0.80239, 
-0.35502, -0.3982 , -0.17492,0.68875, 0.1822 , -0.08046, -0.39238, -0.57619,0.34555, 0.12488, -0.50703, -0.29269, 0.72267,-0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); + + auto expV= NDArrayFactory::create('c', {2,2,6,6}, {-4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01,-1.10690000e-01, 1.37280000e-01,2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03,-1.00090000e-01, 9.35890000e-01, + -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01,6.70320000e-01, 2.10040000e-01,1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01,-4.38680000e-01, 1.83200000e-02, + -5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01,-2.10060000e-01, 2.41550000e-01,-4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02,-5.40210000e-01, -4.97000000e-02, + -6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01,-4.21850000e-01, 4.00490000e-01,1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01,-2.06580000e-01, 7.68890000e-01, + -4.81880000e-01, -6.31100000e-01, 3.40000000e-04, -1.35730000e-01,5.88210000e-01, 7.12900000e-02,2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01,1.51570000e-01, 6.02100000e-02, + 1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02,-5.79750000e-01, -2.92870000e-01,4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01,2.72560000e-01, 3.92350000e-01, + -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01,-1.17970000e-01, -4.08100000e-02,4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01,1.93050000e-01, -6.83340000e-01, + 8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01,2.37500000e-02, 5.78250000e-01,-6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02,-9.19220000e-01, -2.15420000e-01, + 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02,-3.19580000e-01, 2.92020000e-01,2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01,3.39100000e-02, 2.55590000e-01, + 
-4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01,4.98470000e-01, -3.65370000e-01,6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02,-7.30460000e-01, -1.09390000e-01, + -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01,5.20000000e-04, 1.90420000e-01,2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01,-3.57600000e-02, -8.60450000e-01, + 1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01,-4.39400000e-02, 2.17750000e-01,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01, -1.74620000e-01}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *s = result.at(0); + auto *u = result.at(1); + auto *v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if(sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + } + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test10) { + + auto x= NDArrayFactory::create('c', {2,2,5,6}, {17 ,-11 ,20 ,-10 ,19 ,13 ,-18 ,6 ,-2 ,-6 ,-10 ,4 ,-6 ,-4 ,3 ,16 ,12 , + -15 ,8 ,-8 ,12 ,-1 ,20 ,19 ,-13 ,0 ,20 ,17 ,-8 ,16 ,-19 ,7 ,-16 ,-14 ,-5 ,7 ,7 ,-5 ,12 ,-15 ,7 ,8 , + 1 ,-8 ,-17 ,10 ,-11 ,8 ,-10 ,1 ,-6 ,10 ,15 ,19 ,-15 ,8 ,2 ,8 ,12 ,7 ,-5 ,1 ,8 ,4 ,-13 ,2 ,19 ,-2 ,-10 , + -8 ,11 ,1 ,20 ,-11 ,4 ,1 ,-17 ,-15 ,0 ,-9 ,-4 ,-1 ,-6 ,-9 ,-13 ,10 ,7 ,-2 ,15 ,-10 ,-1 ,11 ,-20 ,-2 , + -1 ,-18 ,12 ,16 ,8 ,-9 ,-20 ,-7 ,-20 ,3 ,-9 ,12 ,8 ,-19 ,-2 ,2 ,1 ,7 ,10 ,-18 ,13 ,6 ,14 ,0 ,19 ,8}); + + auto expS= 
NDArrayFactory::create('c', {2,2,5}, {50.46507, 35.75599, 28.12787, 12.45245, 9.08545, + 38.56035, 30.62846, 26.31646, 19.42605, 3.01162, + 38.56369, 29.18881, 19.54565, 10.89746, 2.017 , + 44.99108, 34.95059, 26.00453, 15.43898, 7.18752}); + + auto expU= NDArrayFactory::create('c', {2,2,5,5}, {-0.73644, -0.10751, 0.10081, -0.00325, 0.66025,0.26329, 0.3079 , 0.38582, 0.77696, 0.28872,0.03076, 0.03015, -0.9128 , 0.36387, 0.18039,-0.61335, 0.10076, 0.01381, 0.40922, -0.66783, + -0.10577, 0.93946, -0.0871 , -0.31058, 0.04677,0.52823, 0.31163, -0.78777, 0.02322, -0.05234,-0.23942, -0.45801, -0.34248, 0.71286, 0.32778,0.26147, 0.60409, 0.39933, 0.46862, 0.43318, + 0.62118, -0.37993, 0.30992, 0.34537, -0.50444,0.45763, -0.42877, 0.08128, -0.3904 , 0.66912,-0.05428, 0.53632, 0.19774, -0.32198, 0.75276,-0.21986, -0.8214 , -0.00392, -0.1659 , 0.49944, + -0.79443, 0.1633 , -0.45374, -0.31666, -0.18989,-0.24459, 0.10463, -0.27652, 0.85595, 0.34657,0.50772, 0.00757, -0.82374, -0.18941, 0.16658,0.49473, -0.39923, -0.20758, 0.74339, -0.01213, + -0.2024 , -0.80239, -0.35502, -0.3982 , -0.17492,0.68875, 0.1822 , -0.08046, -0.39238, -0.57619,0.34555, 0.12488, -0.50703, -0.29269, 0.72267,-0.34713, 0.3847 , -0.7532 , 0.22176, -0.33913}); + + auto expV= NDArrayFactory::create('c', {2,2,6,5}, { -4.15640000e-01, -5.30190000e-01, 5.29200000e-02, -7.15710000e-01,-1.10690000e-01,2.86620000e-01, 5.88200000e-02, 1.68760000e-01, -2.55000000e-03,-1.00090000e-01, + -4.88230000e-01, 4.84470000e-01, -1.09150000e-01, -1.46810000e-01,6.70320000e-01,1.00910000e-01, 4.35740000e-01, -6.90500000e-01, -3.61090000e-01,-4.38680000e-01,-5.48440000e-01, -2.86950000e-01, -4.23900000e-01, 5.78540000e-01, + -2.10060000e-01,-4.42450000e-01, 4.56640000e-01, 5.48020000e-01, 3.32100000e-02,-5.40210000e-01,-6.36070000e-01, 5.57600000e-02, 3.28740000e-01, 3.81950000e-01,-4.21850000e-01, + 1.83740000e-01, -1.36190000e-01, -2.29380000e-01, -5.11090000e-01,-2.06580000e-01,-4.81880000e-01, -6.31100000e-01, 
3.40000000e-04, -1.35730000e-01,5.88210000e-01,2.25200000e-01, 4.30600000e-02, 9.08510000e-01, -3.08940000e-01, + 1.51570000e-01,1.97510000e-01, -7.26560000e-01, 1.05370000e-01, 1.10600000e-02,-5.79750000e-01,4.89620000e-01, -2.24300000e-01, 5.31200000e-02, 6.92040000e-01,2.72560000e-01, + -6.84450000e-01, -5.18030000e-01, 2.92000000e-02, -4.96740000e-01,-1.17970000e-01,4.25340000e-01, -1.65500000e-02, -2.82400000e-02, -5.60180000e-01,1.93050000e-01,8.08800000e-02, 4.38260000e-01, -2.48340000e-01, -6.36220000e-01,2.37500000e-02,-6.10000000e-04, 3.00110000e-01, 1.17290000e-01, -6.92400000e-02,-9.19220000e-01, + 5.41330000e-01, -6.61130000e-01, -2.86360000e-01, -2.13500000e-02,-3.19580000e-01,2.25920000e-01, -1.10170000e-01, 9.17020000e-01, -1.71540000e-01,3.39100000e-02,-4.86810000e-01, -2.32390000e-01, -4.31500000e-01, 3.75290000e-01,4.98470000e-01,6.39700000e-02, -4.04150000e-01, -5.28310000e-01, 8.90000000e-02,-7.30460000e-01, + -4.94030000e-01, 1.55540000e-01, -3.46720000e-01, -7.58460000e-01,5.20000000e-04,2.55960000e-01, 3.17040000e-01, -3.47800000e-02, -3.01860000e-01,-3.57600000e-02,1.31650000e-01, 7.57150000e-01, -4.89030000e-01, 3.47710000e-01, + -4.39400000e-02,-6.57270000e-01, 2.91000000e-01, 4.17280000e-01, 2.52880000e-01,-4.63400000e-01}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto *s = result.at(0); + auto *u = result.at(1); + auto *v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if(sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + } 
+ +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test11) { + + NDArray x('c', {2,2,3,3}, {0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, 0.5461, 0.9234, 0.0856, 0.7938, 0.6591, 0.5555, + 0.1596, 0.3087, 0.1548, 0.4695, 0.7788, 0.8012, 0.7244, 0.2309, 0.7271, 0.1804, 0.5056, 0.8925, -0.5461, 0.9234, + 0.0856, -0.7938, 0.6591, 0.5555, 0.1500, 0.3087, 0.1548, 0.4695}); + NDArray expS('c', {2,2,3}, {1.89671, 0.37095, 0.05525,1.51296, 0.52741, 0.17622, 1.69095, 0.90438, 0.24688,1.33551, 0.87475, 0.21571}); + NDArray expU('c', {2,2,3,3}, {6.9205e-01, 6.0147e-01, -3.9914e-01, 3.8423e-01, -7.7503e-01, -5.0170e-01, 6.1110e-01, -1.9384e-01, 7.6746e-01, + 7.8967e-01, 4.5442e-01, -4.1222e-01, 4.9381e-01, -8.6948e-01, -1.2540e-02, 3.6412e-01, 1.9366e-01, 9.1100e-01, + 7.1764e-01, 5.9844e-01, 3.5617e-01, 4.4477e-01, -3.1000e-04, -8.9564e-01, 5.3588e-01, -8.0116e-01, 2.6639e-01, + 8.7050e-01, -4.2088e-01, -2.5513e-01, 4.8622e-01, 6.5499e-01, 5.7843e-01, 7.6340e-02, 6.2757e-01, -7.7481e-01}); + NDArray expV('c', {2,2,3,3}, {0.49383, 0.51614, -0.69981, 0.72718, -0.68641, 0.00688, 0.4768 , 0.51228, 0.7143 , 0.77137, -0.17763, + -0.6111 , 0.26324, -0.7852 , 0.56051, 0.57939, 0.59322, 0.55892, 0.55149, 0.06737, 0.83146, 0.81413, + -0.26072, -0.51887, 0.18182, 0.96306, -0.19863, 0.85948, 0.2707 , -0.4336 , 0.26688, 0.48582, 0.83232, + -0.43596, 0.83108, -0.34531}); + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {0, 1, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + auto u = result.at(1); + auto v = result.at(2); + + ASSERT_TRUE(expS.isSameShape(s)); + ASSERT_TRUE(expU.isSameShape(u)); + ASSERT_TRUE(expV.isSameShape(v)); + + ASSERT_TRUE(expS.equalsTo(s)); + + if(sd::Environment::getInstance().isCPU()) { + ASSERT_TRUE(expU.equalsTo(u)); + ASSERT_TRUE(expV.equalsTo(v)); + } + else { + for(uint i = 0; i < expU.lengthOf(); ++i) + 
ASSERT_NEAR(sd::math::nd4j_abs(expU.e(i)), sd::math::nd4j_abs(u->e(i)), 1e-5); + for(uint i = 0; i < expV.lengthOf(); ++i) + ASSERT_NEAR(sd::math::nd4j_abs(expV.e(i)), sd::math::nd4j_abs(v->e(i)), 1e-5); + } + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, svd_test12) { + + NDArray x('c', {4,3}, {1.7787856,0.80119777,0.72437465,0.23089433,1.7271413,0.18039072,0.50563407,0.89252293,1.5461209,0.92336726,0.085571885,0.79378015}); + NDArray expS('c', {3}, {3.024703, 1.459483, 1.026371}); + + + sd::ops::svd op; + auto result = op.evaluate({&x}, {}, {1, 0, 16}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + + ASSERT_TRUE(expS.equalsTo(s)); + ASSERT_TRUE(expS.isSameShape(s)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, elu_test1) { + + auto x = NDArrayFactory::create('c', {3,3}, {0.1, .2, .3, -.4,-.5,-.6, .7, .8, .9}); + auto exp = NDArrayFactory::create('c', {3,3}, {.1, .2, .3, 0.5*-0.32968, 0.5*-0.393469, 0.5*-0.451188, .7, .8, .9}); + + sd::ops::elu op; + auto result = op.evaluate({&x}, {0.5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, elu_bp_test1) { + + auto x = NDArrayFactory::create('c', {3, 3}, {0.1, .2, .3, -.4, -.5, -.6, .7, .8, .9}); + auto eps = NDArrayFactory::create('c', {3,3}); + eps.assign(2.); + auto exp = NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 0.5*1.34064, 0.5*1.213061, 0.5*1.097623, 2, 2, 2}); + + sd::ops::elu_bp op; + auto result = op.evaluate({ &x, &eps }, {0.5}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, lrelu_test1) { + + auto x = 
NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -0.8, -1., -1.2, 7, 8, 9}); + + sd::ops::lrelu op; + auto result = op.evaluate({&x}, {0.2}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); + +} + +TEST_F(DeclarableOpsTests3, lrelu_bp_test1) { + + auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); + auto eps = NDArrayFactory::create('c', {3,3}, {2,2,2,2,2,2,2, 2,2}); + auto exp = NDArrayFactory::create('c', {3,3}, {2, 2, 2, 0.4, 0.4, 0.4, 2, 2, 2}); + + sd::ops::lrelu_bp op; + auto result = op.evaluate({&x, &eps}, {0.2}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, selu_test1) { + + auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3,3}, {1.050701, 2.101402, 3.152103, -1.725899, -1.746253, -1.753742, 7.354907, 8.405608, 9.456309}); + + sd::ops::selu op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); + ASSERT_TRUE(exp.equalsTo(s)); + +} + +TEST_F(DeclarableOpsTests3, selu_test2) { + + auto x = NDArrayFactory::create('c', {3,3}, {1, 2, 3, -4,-5,-6, 7, 8, 9}); +// auto expS = NDArrayFactory::create('c', {3}); + auto eps = NDArrayFactory::create('c', {3,3}, {2,2,2,2,2,2,2, 2,2}); + auto exp = NDArrayFactory::create('c', {3,3}, {2.101401, 2.101402, 2.101402, 0.064401, 0.023692, 0.008716, 2.101402, 2.101402, 2.101402}); + + sd::ops::selu_bp op; + auto result = op.evaluate({&x, &eps}, {0.2}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto s = result.at(0); +// auto u = result.at(1); +// auto v = result.at(2); +// s->printIndexedBuffer("SELU_BP"); + ASSERT_TRUE(exp.equalsTo(s)); + +} + 
+TEST_F(DeclarableOpsTests3, EQScalarTests_1) { + Graph graph; + + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::eq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); + +} + +TEST_F(DeclarableOpsTests3, EQScalarTests_2) { + Graph graph; + + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::eq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); +} + +TEST_F(DeclarableOpsTests3, GTScalarTests_1) { + Graph graph; + + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::gt_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); +} + +TEST_F(DeclarableOpsTests3, GTScalarTests_2) { + Graph graph; + + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::gt_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); +} + +TEST_F(DeclarableOpsTests3, GTEScalarTests_1) { + Graph graph; + + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::gte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); +} + +TEST_F(DeclarableOpsTests3, GTEScalarTests_2) { + Graph graph; + + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::gte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); +} + +TEST_F(DeclarableOpsTests3, GTEScalarTests_3) { + Graph graph; + + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(2.0f); + + sd::ops::gte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); +} + +TEST_F(DeclarableOpsTests3, LTEScalarTests_1) { + Graph graph; + + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::lte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); +} + +TEST_F(DeclarableOpsTests3, 
LTEScalarTests_2) { + Graph graph; + + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::lte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); +} + +TEST_F(DeclarableOpsTests3, LTEScalarTests_3) { + Graph graph; + + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(2.0f); + + sd::ops::lte_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); +} + +TEST_F(DeclarableOpsTests3, NEQScalarTests_1) { + Graph graph; + + auto x = NDArrayFactory::create(1.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::neq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_FALSE(res); + +} + +TEST_F(DeclarableOpsTests3, NEQScalarTests_2) { + Graph graph; + + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::neq_scalar op; + auto res = op.verify({&x, &scalar}); + ASSERT_TRUE(res); +} + +TEST_F(DeclarableOpsTests3, NOOPTests_1) { + Graph graph; + + auto x = NDArrayFactory::create(2.0f); + auto scalar = NDArrayFactory::create(1.0f); + + sd::ops::noop op; + auto res = op.evaluate({&x, &scalar}, {}, {}); + ASSERT_TRUE(res.status() == sd::Status::OK()); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests4.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests4.cpp new file mode 100644 index 000000000..303f5f116 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests4.cpp @@ -0,0 +1,2450 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include + + +using namespace sd; +using namespace sd::graph; + +class DeclarableOpsTests4 : public testing::Test { +public: + + DeclarableOpsTests4() { + printf("\n"); + fflush(stdout); + + sd::ops::adjust_hue op0; + sd::ops::adjust_saturation op1; + } +}; + +template +class TypedDeclarableOpsTests4 : public testing::Test { +public: + + TypedDeclarableOpsTests4() { + printf("\n"); + fflush(stdout); + + sd::ops::adjust_hue op0; + sd::ops::adjust_saturation op1; + } +}; + +typedef ::testing::Types TestingTypes; +TYPED_TEST_CASE(TypedDeclarableOpsTests4, TestingTypes); + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_1) { + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}); + + x.linspace(1); + + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_2) { + auto x = NDArrayFactory::create('c', {2, 4, 4, 2}); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 
26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}); + + x.linspace(1); + + + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_3) { + auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); + auto exp = NDArrayFactory::create('c', {2, 3, 3, 2}, {7.f, 8.f, 11.f, 12.f, 14.f, 15.f, 27.f, 28.f, 31.f, 32.f, 34.f, 35.f, 42.f, 43.f, 46.f, 47.f, 49.f, 50.f, 57.f, 58.f, 61.f, 62.f, 64.f, 65.f, 77.f, 78.f, 81.f, 82.f, 84.f, 85.f, 92.f, 93.f, 96.f, 97.f, 99.f, 100.f,}); + + x.linspace(1); + + + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//auto padding{top,right, bottom, left} matching arm_compute +std::tuple getSpecialAutoPadding(int rank) { + auto extra_pad_x = rank < 1 ? 0 : 32; + auto pad_x = rank < 1 ? 0 : 4; + auto pad_y = rank < 2 ? 
0 : 4; + return std::tuple{ pad_y, pad_x + extra_pad_x, pad_y, pad_x }; +} + +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_padded_buffer) { + + int top, right, bottom, left; + std::tie(top, right, bottom, left) = getSpecialAutoPadding(4); + + auto input = NDArrayFactory::create('c', {2, 5, 5, 2}, DataTypeUtils::fromT(), {0, 0, top + bottom, left + right}, {0, 0, top, left} ); + auto output = NDArrayFactory::create('c', {2, 3, 3, 2}, DataTypeUtils::fromT(), {0, 0, top + bottom, left + right}, {0, 0, top, left} ); + auto exp = NDArrayFactory::create('c', {2, 3, 3, 2}, {7.f, 8.f, 11.f, 12.f, 14.f, 15.f, 27.f, 28.f, 31.f, 32.f, 34.f, 35.f, 42.f, 43.f, 46.f, 47.f, 49.f, 50.f, 57.f, 58.f, 61.f, 62.f, 64.f, 65.f, 77.f, 78.f, 81.f, 82.f, 84.f, 85.f, 92.f, 93.f, 96.f, 97.f, 99.f, 100.f,}); + + input.linspace(1); + + sd::ops::avgpool2d op; + auto status = op.execute({&input}, {&output}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 1}); + + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_4) { + auto x = NDArrayFactory::create('c', {2, 5, 5, 2}); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, {7.f, 8.f, 11.f, 12.f, 27.f, 28.f, 31.f, 32.f, 57.f, 58.f, 61.f, 62.f, 77.f, 78.f, 81.f, 82.f}); + + x.linspace(1); + + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 0, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_5) { + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {1.f, 2.5f, 4.5f, 8.5f, 10.f, 12.f, 18.5f, 20.f, 22.f, 26.f, 27.5f, 29.5f, 33.5f, 35.f, 37.f, 43.5f, 45.f, 
47.f, 51.f, 52.5f, 54.5f, 58.5f, 60.f, 62.f, 68.5f, 70.f, 72.f, 76.f, 77.5f, 79.5f, 83.5f, 85.f, 87.f, 93.5f, 95.f, 97.f}); + + x.linspace(1); + + + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_6) { + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {0.25f, 1.25f, 2.25f, 4.25f, 10.f, 12.f, 9.25f, 20.f, 22.f, 6.5f, 13.75f, 14.75, 16.75f, 35.f, 37.f, 21.75f, 45.f, 47.f, 12.75f, 26.25f, 27.25f, 29.25f, 60.f, 62.f, 34.25f, 70.f, 72.f, 19.f, 38.75f, 39.75f, 41.75f, 85.f, 87.f, 46.75f, 95.f, 97.f}); + + x.linspace(1); + + + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 1, 1, 1, 1, 0, 1, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_7) { + auto x = NDArrayFactory::create('c', {2, 2, 5, 5}); + auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {4.f, 6.f, 7.5f, 14.f, 16.f, 17.5f, 21.5f, 23.5f, 25.f, 29.f, 31.f, 32.5f, 39.f, 41.f, 42.5f, 46.5f, 48.5f, 50.f, 54.f, 56.f, 57.5f, 64.f, 66.f, 67.5f, 71.5f, 73.5f, 75.f, 79.f, 81.f, 82.5f, 89.f, 91.f, 92.5f, 96.5f, 98.5f, 100.f}); + + x.linspace(1); + + + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 2, 2, 0, 0, 1, 1, 1, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// 
+TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_8) { + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto exp = NDArrayFactory::create('c', {1, 1, 2, 2}, {3.f, 4.f, 6.f, 7.f}); + + x.linspace(1); + + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_9) { + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto exp = NDArrayFactory::create('c', {1, 1, 3, 3}, {3.f, 4.f, 4.5f, 6.f, 7.f, 7.5f, 7.5f, 8.5f, 9.f}); + + x.linspace(1); + + + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1, 0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + //z->printShapeInfo("z shape:"); + //z->printBuffer("z buffer:"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, avgpool2d_10) { + + auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f, 2.91434479f, 5.43639755f, -2.10573769f, 4.08528662f, 5.86908436f, -4.46203756f, 2.21057916f, 5.35849190f, 0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, + 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, + -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 
8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, + -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, + 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, -2.30959272f, 9.55441570f, + 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, + -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, + 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, + -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 
9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, + 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, + 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, -0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, 
-2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, -6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, 
-1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, -4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 
8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, -1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 
3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, -2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, 
-0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, -2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, 
-1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); + auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 
3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, -0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, 
-2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); + + sd::ops::avgpool2d op; + auto result = op.evaluate({&input}, {3,3, 3,3, 0,0, 1,1,1, 0,1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + // z->printIndexedBuffer("z"); + // exp.printIndexedBuffer("e"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_11) { + int inOutH = 5;// 35; + int inOutW = 5;// 35; + int inOutC = 10;// 192; + + auto x = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + x.linspace(1.0); + + sd::ops::avgpool2d op; + auto result = op.evaluate({&x}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH; + int padTop = totalPadHeight / 2; + int padBottom = totalPadHeight - totalPadHeight / 2; + + int k = 3; + + auto m = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + auto c = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + + for (int h = 0; h < inOutH; h++) { + for (int w = 0; w < inOutW; w++) { + int hFrom = h - padTop; + int wFrom = w - padBottom; + + int hTo = hFrom + k; + int wTo = wFrom + k; + + hFrom = sd::math::nd4j_max(0, hFrom); + wFrom = sd::math::nd4j_max(0, wFrom); + + hTo = sd::math::nd4j_min(inOutH, hTo); + wTo = sd::math::nd4j_min(inOutW, wTo); + + int idxOut[4]; + int idxIn[4]; + for (int ch = 0; ch < inOutC; ch++) { + idxOut[1] = h; + idxOut[2] = w; + idxOut[3] = ch; + idxIn[3] = ch; + + for (int kh = hFrom; kh < hTo; kh++) { + for (int kw = wFrom; kw < wTo; kw++) { + idxIn[1] = kh; + idxIn[2] = kw; + + auto inVal = x.e(0, kh, kw, ch); + m.p(0, h, w, ch, inVal + m.e(0, h, w, ch)); + c.p(0, h, w, ch, 1 + c.e(0, h, w, ch)); + } + } + } + } + } + m /= c; + + ASSERT_EQ(m, *z); + + +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_12) { + + int bS=4, iH=10,iW=10, iC=3, kH=3,kW=3, sH=3,sW=3, pH=0,pW=0, dH=1,dW=1; + int oH=4, oW=4; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NHWC, 0-NDHW + + auto input = NDArrayFactory::create('c', {bS, iH, iW, iC}); + auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC}, { 17.5, 18.5, 19.5, 25. , 26. , 27. , 34. , 35. , 36. , 41.5, 42.5, 43.5, 92.5, 93.5, 94.5, 100. , 101. , 102. , 109. , 110. , 111. , 116.5, 117.5, 118.5, + 182.5, 183.5, 184.5, 190. , 191. , 192. , 199. , 200. , 201. , 206.5, 207.5, 208.5, 257.5, 258.5, 259.5, 265. , 266. , 267. , 274. , 275. , 276. , 281.5, 282.5, 283.5, + 317.5, 318.5, 319.5, 325. , 326. , 327. , 334. , 335. , 336. , 341.5, 342.5, 343.5, 392.5, 393.5, 394.5, 400. , 401. , 402. , 409. , 410. , 411. , 416.5, 417.5, 418.5, + 482.5, 483.5, 484.5, 490. , 491. , 492. , 499. , 500. , 501. , 506.5, 507.5, 508.5, 557.5, 558.5, 559.5, 565. , 566. , 567. , 574. , 575. , 576. , 581.5, 582.5, 583.5, + 617.5, 618.5, 619.5, 625. , 626. , 627. , 634. , 635. , 636. , 641.5, 642.5, 643.5, 692.5, 693.5, 694.5, 700. , 701. , 702. , 709. , 710. , 711. , 716.5, 717.5, 718.5, + 782.5, 783.5, 784.5, 790. , 791. , 792. , 799. , 800. , 801. , 806.5, 807.5, 808.5, 857.5, 858.5, 859.5, 865. , 866. , 867. , 874. , 875. , 876. , 881.5, 882.5, 883.5, + 917.5, 918.5, 919.5, 925. , 926. , 927. , 934. , 935. , 936. , 941.5, 942.5, 943.5, 992.5, 993.5, 994.5,1000. , 1001. , 1002. ,1009. , 1010. , 1011. ,1016.5, 1017.5, 1018.5, + 1082.5, 1083.5, 1084.5,1090. , 1091. , 1092. ,1099. , 1100. , 1101. ,1106.5, 1107.5, 1108.5,1157.5, 1158.5, 1159.5,1165. , 1166. , 1167. ,1174. , 1175. , 1176. 
,1181.5, 1182.5, 1183.5}); + input.linspace(1.); + + sd::ops::avgpool2d op; + auto results = op.evaluate({&input}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + //output->printIndexedBuffer("output"); + //expected.printIndexedBuffer("expected"); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_13) { + + const int bS = 2; // batch size + const int iD = 1; // input depth (number of picture channels, for example rgb=3) + const int iH = 28; // picture height in pixels + const int iW = 28; // picture width in pixels + const int kH = 5; // kernel height in pixels + const int kW = 5; // kernel width in pixels + const int sH = 1; // stride step in horizontal direction + const int sW = 1; // stride step in vertical direction + const int pH = 0; // padding height + const int pW = 0; // padding width + const int dH = 2; // dilation height + const int dW = 2; // dilation width + const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height + const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = 
variableSpace->getVariable(block->getNodeId())->getNDArray(); + ASSERT_TRUE(exp.isSameShape(result)); + + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_14) { + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (iH - kH - (kH-1)*(dH-1) + 2*pH)/sH + 1; // output height + const int oW = (iW - kW - (kW-1)*(dW-1) + 2*pW)/sW + 1; // output width + + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 0, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result->printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, Avgpool2d_test15) { + const int bS = 2; + const int iD = 1; + const int iH = 28; + const int iW = 28; + const int kH = 5; + const int kW = 5; + const int sH = 1; + const int sW = 1; + const int pH = 0; + const int pW = 0; + const int dH = 1; + const int dW = 1; + const int oH = (int) sd::math::nd4j_ceil(iH * 1.f / sH); + const int oW = (int) 
sd::math::nd4j_ceil(iW * 1.f / sW); + + + auto x = NDArrayFactory::create_('c', {bS,iD,iH,iW}); + auto exp = NDArrayFactory::create('c',{bS,iD,oH,oW}); + // auto z('c',{bS,iD,oH,oW}); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, x); + // variableSpace->putVariable(1, &z); + + auto block = new Context(1, variableSpace, false); + block->fillInputs({-1}); + std::vector* argI = block->getIArguments(); + *argI = {kH,kW, sH,sW, pH,pW, dW,dH, 1, 0, 0}; // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + sd::ops::avgpool2d pooling; + Nd4jStatus status = pooling.execute(block); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = variableSpace->getVariable(block->getNodeId())->getNDArray(); + // result->printShapeInfo(); + ASSERT_TRUE(exp.isSameShape(result)); + + delete variableSpace; + delete block; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, avgpool2d_16) { + + int bS=2, iH=4,iW=4, iC=2, kH=2,kW=2, sH=2,sW=2, pH=0,pW=0, dH=1,dW=1; + int oH=2,oW=2; + int paddingMode = 1; // 1-SAME, 0-VALID + int dataFormat = 1; // 1-NHWC, 0-NDHW + + NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32); + NDArray output('f', {bS, oH, oW, iC}, sd::DataType::FLOAT32); + NDArray expected('c', {bS, oH, oW, iC}, {6.f, 7.f, 10.f, 11.f, 22.f, 23.f, 26.f, 27.f, 38.f, 39.f, 42.f, 43.f, 54.f, 55.f, 58.f, 59.f}, sd::DataType::FLOAT32); + + input.linspace(1.); + + sd::ops::avgpool2d op; + auto status = op.execute({&input}, {&output}, {}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, 0, dataFormat}, {}); + + ASSERT_EQ(Status::OK(), status); + + // output.printBuffer(); + //expected.printIndexedBuffer("expected"); + + ASSERT_TRUE(expected.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, biasadd_1) { + auto x = NDArrayFactory::create('c', {2, 3, 3, 2}); + 
auto bias = NDArrayFactory::create('c', {2}, {1, 2}); + auto exp = NDArrayFactory::create('c', {2, 3, 3, 2}, {1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f, 1.f, 2.f}); + + sd::ops::biasadd op; + auto result = op.evaluate({&x, &bias}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, biasadd_2) { + auto x = NDArrayFactory::create('c', {2, 2, 3, 3}); + auto bias = NDArrayFactory::create('c', {2}, {1, 2}); + auto exp = NDArrayFactory::create('c', {2, 2, 3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2}); + + sd::ops::biasadd op; + auto result = op.evaluate({&x, &bias}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, biasadd_3) { + auto x = NDArrayFactory::create('c', {2, 3}); + auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + + sd::ops::biasadd op; + auto result = op.evaluate({&x, &row}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, biasadd_bp_1) { + + NDArray x('c', {2,2,2,3}, {1.,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); + NDArray gradO('c', {2,2,2,3}, sd::DataType::FLOAT32); + NDArray bias('c', {3}, {-1., -2, -3}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {3}, {9.2, 10. 
, 10.8}, sd::DataType::FLOAT32); + + gradO.linspace(0.1, 0.1); + + sd::ops::biasadd_bp op; + auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {false}); // NHWC + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto gradI = result.at(0); + auto gradB = result.at(1); + + ASSERT_TRUE(gradI->isSameShape(gradO)); + ASSERT_TRUE(gradI->equalsTo(gradO)); + + ASSERT_TRUE(gradB->isSameShape(expGradB)); + ASSERT_TRUE(gradB->equalsTo(expGradB)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, biasadd_bp_2) { + + NDArray x('c', {2,3,2,2}, {1.,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); + NDArray gradO('c', {2,3,2,2}, sd::DataType::FLOAT32); + NDArray bias('c', {3}, {-1., -2, -3}, sd::DataType::FLOAT32); + + NDArray expGradB('c', {3}, {6.8, 10., 13.2}, sd::DataType::FLOAT32); + + gradO.linspace(0.1, 0.1); + + sd::ops::biasadd_bp op; + auto result = op.evaluate({&x, &bias, &gradO}, {}, {}, {true}); // NCHW + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto gradI = result.at(0); + auto gradB = result.at(1); + + ASSERT_TRUE(gradI->isSameShape(gradO)); + ASSERT_TRUE(gradI->equalsTo(gradO)); + + ASSERT_TRUE(gradB->isSameShape(expGradB)); + ASSERT_TRUE(gradB->equalsTo(expGradB)); + + +} + +TEST_F(DeclarableOpsTests4, biasadd_4) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + auto x = NDArrayFactory::create('c', {2, 3}); + auto y = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f}); + + sd::ops::biasadd op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp, z); +} + +TEST_F(DeclarableOpsTests4, Test_Fill_1) { + auto x = NDArrayFactory::create('c', {1, 3}, {3, 2, 4}); + auto v = NDArrayFactory::create(2.); + auto exp = NDArrayFactory::create('c', {3, 2, 
4}); + exp.assign(2.0f); + + sd::ops::fill op; + auto result = op.evaluate({&x, &v}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_FirasSparce_1) { + auto x = NDArrayFactory::create('c', {1, 81}); + auto exp = NDArrayFactory::create('c', {1, 2}, {0, 1}); + + x.p(51, 1); + x.p(52, 0); + x.p(60, 1); + x.p(61, 0); + sd::ops::firas_sparse op; + auto result = op.evaluate({&x}, {0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("FIRAS"); +// z->printShapeInfo("OUTSHAPE"); +// ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_FlattenTests_1) { + auto x = NDArrayFactory::create('c', {3, 3, 3, 3}); + auto exp = NDArrayFactory::create('c', {81}); + + x.linspace(1); + exp.linspace(1); + sd::ops::flatten op; + auto result = op.evaluate({&x}, {}, {'c'}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Flatten1"); +// z->printShapeInfo("Flatten1 shape"); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_FlattenTests_2) { + auto x = NDArrayFactory::create('c', {3, 3, 3, 3}); + auto y = NDArrayFactory::create('c', {3, 3}); + auto exp = NDArrayFactory::create('c', {90}); + + x.linspace(1); + y.linspace(82); + exp.linspace(1); + sd::ops::flatten op; + auto result = op.evaluate({&x, &y}, {}, {'c'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Flatten2"); +// z->printShapeInfo("Flatten2 shape"); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_FlattenTests_3) { + NDArray x('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('f', {2,2}, sd::DataType::INT32); + NDArray exp('c', {8}, {1, 2, 3, 4, 1, 2, 3, 4}, sd::DataType::INT32); + + y.assign(x); + + 
sd::ops::flatten op; + auto result = op.evaluate({&x, &y}, {}, {'c'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_FlattenTests_4) { + NDArray x('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray y('f', {2,2}, sd::DataType::INT32); + NDArray exp('c', {8}, {1, 3, 2, 4, 1, 3, 2, 4}, sd::DataType::INT32); + + y.assign(x); + + sd::ops::flatten op; + auto result = op.evaluate({&x, &y}, {}, {'f'}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_FloorTests_1) { + auto x = NDArrayFactory::create('c', {3, 3}, {1.5, 2.3, 3.4, 4.3, 5.9, 6.1, 7.2, 8.9, 9.7}); + auto exp = NDArrayFactory::create('c', {3,3}); + + exp.linspace(1); + sd::ops::Floor op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Flatten1"); +// z->printShapeInfo("Flatten1 shape"); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_Split_1) { + auto x = NDArrayFactory::create('c', {5, 30}); + auto sizes = NDArrayFactory::create('c', {1, 3}, {4, 15, 11}); + + std::vector list0({0,0, 0,4}); + std::vector list1({0,0, 4,19}); + std::vector list2({0,0, 19,30}); + + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); + + sub0.assign(0.0); + sub1.assign(1.0); + sub2.assign(2.0); + + + sd::ops::split_v op; + auto result = op.evaluate({&x, &sizes}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(3, result.size()); + + auto z0 = result.at(0); + auto z1 = result.at(1); + auto z2 = result.at(2); + + ASSERT_TRUE(sub0.isSameShape(z0)); + ASSERT_TRUE(sub1.isSameShape(z1)); + ASSERT_TRUE(sub2.isSameShape(z2)); + + ASSERT_TRUE(sub0.equalsTo(z0)); + ASSERT_TRUE(sub1.equalsTo(z1)); + ASSERT_TRUE(sub2.equalsTo(z2)); + + +} + +// special test for TF 
mode, when axis goes first +TEST_F(DeclarableOpsTests4, Test_Split_2) { + auto x = NDArrayFactory::create('c', {5, 12}); + auto axis = NDArrayFactory::create('c', {1, 1}, {1.f}); + + std::vector list0 = {0,0, 0,3}; + std::vector list1 = {0,0, 3,6}; + std::vector list2 = {0,0, 6,9}; + std::vector list3 = {0,0, 9,12}; + + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); + auto sub3 = x(list3, true); + + sub0.assign(0.0f); + sub1.assign(1.0f); + sub2.assign(2.0f); + sub3.assign(3.0f); + + + sd::ops::split op; + auto result = op.evaluate({&axis, &x}, {}, {4}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z0 = result.at(0); + auto z1 = result.at(1); + auto z2 = result.at(2); + auto z3 = result.at(3); + + ASSERT_TRUE(sub0.isSameShape(z0)); + ASSERT_TRUE(sub1.isSameShape(z1)); + ASSERT_TRUE(sub2.isSameShape(z2)); + ASSERT_TRUE(sub3.isSameShape(z3)); + + ASSERT_TRUE(sub0.equalsTo(z0)); + ASSERT_TRUE(sub1.equalsTo(z1)); + ASSERT_TRUE(sub2.equalsTo(z2)); + ASSERT_TRUE(sub3.equalsTo(z3)); + + +} + +// special test for TF mode, when axis goes first +TEST_F(DeclarableOpsTests4, Test_Split_3) { + auto x = NDArrayFactory::create('c', {6, 12}); + auto axis = NDArrayFactory::create('c', {1, 1}, {0.f}); + + std::vector list0 = {0,2, 0,0}; + std::vector list1 = {2,4, 0,0}; + std::vector list2 = {4,6, 0,0}; + + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); + + sub0.assign(0.0f); + sub1.assign(1.0f); + sub2.assign(2.0f); + + sd::ops::split op; + auto result = op.evaluate({&axis, &x}, {}, {3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z0 = result.at(0); + auto z1 = result.at(1); + auto z2 = result.at(2); + + ASSERT_TRUE(sub0.isSameShape(z0)); + ASSERT_TRUE(sub1.isSameShape(z1)); + ASSERT_TRUE(sub2.isSameShape(z2)); + + ASSERT_TRUE(sub0.equalsTo(z0)); + ASSERT_TRUE(sub1.equalsTo(z1)); + ASSERT_TRUE(sub2.equalsTo(z2)); +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, split_test4) { + + auto input = NDArrayFactory::create('c', {10},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f}); + auto axis = NDArrayFactory::create(-1); + auto exp1 = NDArrayFactory::create('c', {5}, {1.f,2.f,3.f,4.f,5.f}); + auto exp2 = NDArrayFactory::create('c', {5}, {6.f,7.f,8.f,9.f,10.f}); + + sd::ops::split op; + auto results = op.evaluate({&input, &axis}, {}, {2}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out1 = results.at(0); + auto out2 = results.at(1); + + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.equalsTo(out2)); +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, split_test5) { + + auto input = NDArrayFactory::create('c', {3,8},{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f}); + auto exp1 = NDArrayFactory::create('c', {3,4}, {1.f,2.f,3.f,4.f, 9.f,10.f,11.f,12.f, 17.f,18.f,19.f,20.f}); + auto exp2 = NDArrayFactory::create('c', {3,4}, {5.f,6.f,7.f,8.f, 13.f,14.f,15.f,16.f, 21.f,22.f,23.f,24.f}); + + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {2,-1},{}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto out1 = results.at(0); + auto out2 = results.at(1); + + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.equalsTo(out2)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, split_test6) { + + NDArray input('c', {0,4}, sd::DataType::FLOAT32); + std::vector expShape = {0,1}; + + const int numSplits = 4; + const int axis = 1; + + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + for (int i = 0; i 
< numSplits; ++i) + ASSERT_TRUE(results.at(i)->isSameShape(expShape)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, split_test7) { + + NDArray input('c', {0,4}, sd::DataType::FLOAT32); + std::vector expShape = {0,4}; + + const int numSplits = 4; + const int axis = 0; + + sd::ops::split op; + auto results = op.evaluate({&input}, {}, {numSplits, axis}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + for (int i = 0; i < numSplits; ++i) + ASSERT_TRUE(results.at(i)->isSameShape(expShape)); +} + + +TEST_F(DeclarableOpsTests4, Test_Squeeze_args_1) { + auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); + + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {1, 3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_Squeeze_args_2) { + auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {2}, {1.f, 3.f}); + auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); + + sd::ops::squeeze op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests4, Test_Squeeze_args_3) { + auto x = NDArrayFactory::create('c', {2, 1, 1, 1, 2}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 1, 2}, {1, 2, 3, 4}); + + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {-2, -3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_1) { + auto x = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 
10, 11, 12}); + auto exp = NDArrayFactory::create('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + sd::ops::space_to_depth op; + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_SpaceToDepth_2) { + auto x = NDArrayFactory::create('c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); + + sd::ops::space_to_depth op; + auto result = op.evaluate({&x}, {}, {2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests4, Test_DepthToSpace_1) { + auto x = NDArrayFactory::create('c', {1, 1, 1, 12}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + sd::ops::depth_to_space op; + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests4, Test_DepthToSpace_2) { + auto x = NDArrayFactory::create('c', {1, 12, 1, 1}, {1, 5, 9, 2, 6, 10, 3, 7, 11, 4, 8, 12}); + auto exp = NDArrayFactory::create('c', {1, 3, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + sd::ops::depth_to_space op; + auto result = op.evaluate({&x}, {}, {2, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_DepthToSpace_3) { + auto x = NDArrayFactory::create('c', {4, 4, 16, 16}); + auto exp = NDArrayFactory::create('c', {4, 16, 64, 1}); + + sd::ops::depth_to_space op; + auto result = 
op.evaluate({&x}, {}, {4, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + +} + + +TEST_F(DeclarableOpsTests4, Test_Cross_1) { + auto a = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto b = NDArrayFactory::create('c', {3}, {6, 7, 8}); + auto exp = NDArrayFactory::create('c', {3}, {-5, 10, -5}); + + sd::ops::cross op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests4, Test_Cross_2) { + auto a = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + auto b = NDArrayFactory::create('c', {2, 3}, {6, 7, 8, 6, 7, 8}); + auto exp = NDArrayFactory::create('c', {2, 3}, {-5, 10, -5, -5, 10, -5}); + + sd::ops::cross op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests4, Test_Cross_3) { + auto a = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto b = NDArrayFactory::create('c', {3, 3}, {2, 3, 4, 7, 6, 5, 6, 3, 2}); + auto exp = NDArrayFactory::create('c', {3, 3}, { -1, 2, -1, -11, 22, -11, -11, 40, -27}); + + sd::ops::cross op; + auto result = op.evaluate({&a, &b}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_Add_119) { + auto a = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto b = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 4}, {2, 4, 6, 8}); + + sd::ops::add op; + auto result = op.evaluate({&a, &b}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_EQ(2, z->rankOf()); + + ASSERT_TRUE(exp.isSameShape(z)); + 
ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_TileToShape_1) { + auto x = NDArrayFactory::create('c', {2, 1, 3}); + auto exp = NDArrayFactory::create('c', {2, 4, 3}, {1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f,1.f, 2.f, 3.f, + 4.f, 5.f, 6.f,4.f, 5.f, 6.f,4.f, 5.f, 6.f,4.f, 5.f, 6.f}); + x.linspace(1.f); + + sd::ops::tile_to_shape op; + auto result = op.evaluate({&x},{}, {2, 4, 3}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_1) { + auto x = NDArrayFactory::create('c', {3, 4, 5}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1,3,4,5}); + exp.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&x}, {}, {0,0,0,1,0, -999,0,0,0, -999,3,4,5, -999,1,1,1}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) { + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto begin = NDArrayFactory::create('c', {4}, {-999,0,0,0}); + auto end = NDArrayFactory::create('c', {4}, {-999,3,4,5}); + auto stride = NDArrayFactory::create('c', {4}, {-999,1,1,1}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1,3,4,5}); + exp.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {0,0,0,1,0}); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) { + int axis = 0; + auto x = NDArrayFactory::create('c', {1}, {10}); + auto begin = NDArrayFactory::create('c', {1}, {axis}); + auto end = NDArrayFactory::create('c', {1}, {axis}); + auto stride = NDArrayFactory::create('c', 
{1}, {1}); + //x.linspace(1); + //auto exp = NDArrayFactory::create('c', {1,3,4,5}); + //exp.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1,0,0,0,0}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z->isEmpty()); + + +} +TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_4) { + auto x = NDArrayFactory::create('c', {1,3}, {1, 2, 3}); + auto begin = NDArrayFactory::create('c', {2}, {0, 0}); + auto end = NDArrayFactory::create('c', {2}, {0,1}); + auto stride = NDArrayFactory::create('c', {2}, {1,1}); +// x.linspace(1); + auto exp = NDArrayFactory::create('c', {1}, {1}); + //exp.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &begin, &end, &stride}, {}, {1,0,1,0,2}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z->lengthOf() == 1); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, parallel_stack_test1) { + + auto x1 = NDArrayFactory::create('c', {2,2,2}); + auto x2 = NDArrayFactory::create('c', {2,2,2}); + auto x3 = NDArrayFactory::create('c', {2,2,2}); + x1.linspace(1); + x2.linspace(9); + x3.linspace(17); + + auto expected = NDArrayFactory::create('c', {3,2,2,2}); + expected.linspace(1); + + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, parallel_stack_test2) { + + auto x1 = NDArrayFactory::create('c', {1,2}, {1,2}); + auto x2 = NDArrayFactory::create('c', {1,2}, {3,4}); + auto x3 = NDArrayFactory::create('c', {1,2}, {5,6}); + + auto expected = NDArrayFactory::create('c', {3,1,2}, {1,2,3,4,5,6}); + + 
sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, parallel_stack_test3) { + + auto x1 = NDArrayFactory::create('c', {2,1}, {1,2}); + auto x2 = NDArrayFactory::create('c', {2,1}, {3,4}); + auto x3 = NDArrayFactory::create('c', {2,1}, {5,6}); + + auto expected = NDArrayFactory::create('c', {3,2,1}, {1,2,3,4,5,6}); + + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} +\ +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, parallel_stack_test4) { + + auto x1 = NDArrayFactory::create('c', {2}, {1,2}); + auto x2 = NDArrayFactory::create('c', {2}, {3,4}); + auto x3 = NDArrayFactory::create('c', {2}, {5,6}); + + auto expected = NDArrayFactory::create('c', {3,2}, {1,2,3,4,5,6}); + + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, parallel_stack_test5) { + + auto x1 = NDArrayFactory::create('c', {1}, {1}); + auto x2 = NDArrayFactory::create('c', {1}, {3}); + auto x3 = NDArrayFactory::create('c', {1}, {5}); + + auto expected = NDArrayFactory::create('c', {3,1}, {1,3,5}); + + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + 
ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, parallel_stack_test6) { + + auto x1 = NDArrayFactory::create(1.); + auto x2 = NDArrayFactory::create(3.); + auto x3 = NDArrayFactory::create(5.); + + auto expected = NDArrayFactory::create('c', {3}, {1,3,5}); + + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1, &x2, &x3}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, parallel_stack_test7) { + + auto x1 = NDArrayFactory::create(1.); + auto expected = NDArrayFactory::create('c', {1}, {1.}); + + sd::ops::parallel_stack op; + auto results = op.evaluate({&x1}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, meshgrid_test1) { + + auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {3}, {10, 20, 30}); + auto in2 = NDArrayFactory::create('c', {4}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create('c', {2,3,4}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create('c', {2,3,4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create('c', {2,3,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + 
auto out1 = results.at(1); + auto out2 = results.at(2); + + // out0->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, meshgrid_test2) { + + auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {3}, {10, 20, 30}); + auto in2 = NDArrayFactory::create('c', {4}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create('c', {3,2,4}, {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create('c', {3,2,4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create('c', {3,2,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, meshgrid_test3) { + + auto in0 = NDArrayFactory::create('c', {2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {1,3}, {10, 20, 30}); + auto in2 = NDArrayFactory::create('c', {2,2}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create('c', {3,2,4}, {1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 
2, 2}); + auto exp1 = NDArrayFactory::create('c', {3,2,4}, {10, 10, 10, 10, 10, 10, 10, 10, 20, 20, 20, 20, 20, 20, 20, 20, 30, 30, 30, 30, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create('c', {3,2,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, meshgrid_test4) { + + auto in0 = NDArrayFactory::create('c', {1,2}, {1, 2}); + auto in1 = NDArrayFactory::create('c', {3,1}, {10, 20, 30}); + auto in2 = NDArrayFactory::create('c', {1,4,1}, {100, 200, 300, 400}); + auto exp0 = NDArrayFactory::create('c', {2,3,4}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto exp1 = NDArrayFactory::create('c', {2,3,4}, {10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30, 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30}); + auto exp2 = NDArrayFactory::create('c', {2,3,4}, {100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400, 100, 200, 300, 400}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); + + +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, meshgrid_test5) { + + auto in0 = NDArrayFactory::create(1); + auto in1 = NDArrayFactory::create(2); + auto in2 = NDArrayFactory::create(3); + auto exp0 = NDArrayFactory::create('c', {1,1,1}, {1}); + auto exp1 = NDArrayFactory::create('c', {1,1,1}, {2}); + auto exp2 = NDArrayFactory::create('c', {1,1,1}, {3}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, meshgrid_test6) { + + auto in0 = NDArrayFactory::create('c', {2,2},{1,2,3,4}); + auto in1 = NDArrayFactory::create(5); + auto in2 = NDArrayFactory::create(6); + auto exp0 = NDArrayFactory::create('c', {4,1,1}, {1,2,3,4}); + auto exp1 = NDArrayFactory::create('c', {4,1,1}, {5,5,5,5}); + auto exp2 = NDArrayFactory::create('c', {4,1,1}, {6,6,6,6}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {0}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, meshgrid_test7) { + + auto in0 = NDArrayFactory::create('c', {2,2},{1,2,3,4}); + auto in1 = NDArrayFactory::create(5); + auto in2 = 
NDArrayFactory::create(6); + auto exp0 = NDArrayFactory::create('c', {1,4,1}, {1,2,3,4}); + auto exp1 = NDArrayFactory::create('c', {1,4,1}, {5,5,5,5}); + auto exp2 = NDArrayFactory::create('c', {1,4,1}, {6,6,6,6}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0, &in1, &in2}, {}, {1}); + auto out0 = results.at(0); + auto out1 = results.at(1); + auto out2 = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + ASSERT_TRUE(exp1.isSameShape(out1)); + ASSERT_TRUE(exp1.equalsTo(out1)); + ASSERT_TRUE(exp2.isSameShape(out2)); + ASSERT_TRUE(exp2.equalsTo(out2)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, meshgrid_test8) { + + auto in0 = NDArrayFactory::create(5); + auto exp0 = NDArrayFactory::create('c', {1}, {5}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0}, {}, {0}); + auto out0 = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, meshgrid_test9) { + + auto in0 = NDArrayFactory::create(5); + auto exp0 = NDArrayFactory::create('c', {1}, {5}); + + sd::ops::meshgrid op; + auto results = op.evaluate({&in0}, {}, {1}); + auto out0 = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp0.isSameShape(out0)); + ASSERT_TRUE(exp0.equalsTo(out0)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_1) { + + + auto input = NDArrayFactory::create('c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}); + auto targets = NDArrayFactory::create('c', {2, 3}, {15.5f, 15.7f, 5.f , 15.f, 5.f, 6.f}); + auto weight = NDArrayFactory::create(0.7f); + auto expected = NDArrayFactory::create('c', {2, 3}, 
{-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887}); + +//Targets {15.5f, 15.7f, 5.f , 15.f, 5.f, 6.f}; +//---------- +//Inputs {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}; +//---------- +//Weights [0.7] +//Result {-159.50006, -191.1, -16.009075, -210., -24.001238, -15.03887} + + sd::ops::weighted_cross_entropy_with_logits op; + auto results = op.evaluate({&targets, &input, &weight}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, WeightedCrossEntropyWithLogits_2) { + + + auto input = NDArrayFactory::create('c', {2, 3}, {11.f, 13.f, 4.f, 15.f, 6.f, 3.f}); + auto targets = NDArrayFactory::create('c', {2, 3}, {15.5f, 15.7f, 5.f, 15.f, 5.f, 6.f}); + auto weights = NDArrayFactory::create({0.5f, 0.7f, 1.0f}) ; + auto expected = NDArrayFactory::create('c', {2, 3}, {-159.5001f, -191.1f, -15.98185f, -210.f, -24.001238f, -14.951412f}); + + sd::ops::weighted_cross_entropy_with_logits op; + auto results = op.evaluate({&targets, &input, &weights}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, lstm_test1) { + + const int time = 5; + const int batchSize = 3; + const int inSize = 3; + const int numProj = 3; + const int numUnits = 3; + + auto x = NDArrayFactory::create('c', {time, batchSize, inSize}); + auto h0 = NDArrayFactory::create('c', {batchSize, numProj}); + auto c0 = NDArrayFactory::create('c', {batchSize, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, 4*numUnits}); + auto Wh = NDArrayFactory::create('c', {numProj, 4*numUnits}); + auto Wc = 
NDArrayFactory::create('c', {3*numUnits}); + auto Wp = NDArrayFactory::create('c', {numUnits, numProj}); + auto b = NDArrayFactory::create('c', {4*numUnits}); + + x.linspace(0.5, 0.5); + h0 = 1.; + c0 = 2.; + Wx = 0.003; + Wh = 0.006; + Wc = 0.; + Wp = 0.; + b = 0.5; + + auto expH = NDArrayFactory::create('c', {time, batchSize, numProj}, {0.57574,0.57574,0.57574,0.58006,0.58006,0.58006,0.58434,0.58434,0.58434, + 0.55114,0.55114,0.55114,0.55732,0.55732,0.55732,0.56338,0.56338,0.56338, + 0.53763,0.53763,0.53763,0.54534,0.54534,0.54534,0.55287,0.55287,0.55287, + 0.53626,0.53626,0.53626,0.54487,0.54487,0.54487,0.55327,0.55327,0.55327, + 0.54484,0.54484,0.54484,0.55379,0.55379,0.55379,0.5625 ,0.5625 ,0.5625}); + + auto expClast = NDArrayFactory::create('c', {1, batchSize, numProj}, {1.1589154,1.1589154,1.1589154,1.1892855,1.1892855,1.1892855,1.219861 ,1.219861 ,1.219861}); + + sd::ops::lstm op; + auto results = op.evaluate({&x, &h0, &c0, &Wx, &Wh, &Wc, &Wp, &b}, {0., 0., 0.}, {0, 0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto *h = results.at(0); + auto *c = results.at(1); + auto cLast = (*c)({4,5,0,0,0,0},true); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + + ASSERT_TRUE(expClast.isSameShape(&cLast)); + ASSERT_TRUE(expClast.equalsTo(&cLast)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, relu6_test1) { + + auto input = NDArrayFactory::create('c', {2,4}, {-13.,10,-5,0,2,7,6,12}); + auto expected = NDArrayFactory::create('c', {2,4}, {0., 6., 0., 0.,2., 6., 6., 6.}); + + sd::ops::relu6 op; + auto results = op.evaluate({&input}, {0.}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, relu6_bp_test1) { + + auto input = 
NDArrayFactory::create('c', {2,4}, {-13.,10, -5, 0, 2, 7, 6, 5}); + auto gradO = NDArrayFactory::create('c', {2,4}, {-1., -2., 0., 4., 5., 6., 7., 8.}); + + auto expected = NDArrayFactory::create('c', {2,4}, {0., 0., 0., 0., 5., 0., 0., 8.}); + + sd::ops::relu6_bp op; + auto results = op.evaluate({&input, &gradO}, {0.}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, + 8.6f, 0.f, 0.f, 0.4f, + 1.5f, 1.f, 1.3f, 1.5f, + 2.6f, 2.f, 3.f, 1.4f} + ); + + auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { + 0.98386997f, 0.f, 0.05358852f, 0.9824562f, + 0.99330735f, 0.f, 0.f, 0.37139067f, + 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f} + ); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_2) { + + auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, + 8.6f, 0.f, 0.f, 0.4f, + 1.5f, 1.f, 1.3f, 1.5f, + 2.6f, 2.f, 3.f, 1.4f}); + + auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { + 0.98386997f, 0.f, 0.05358852f, 0.9824562f, + 0.99330735f, 0.f, 0.f, 0.37139067f, + 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f}); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = 
results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_3) { + + auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { + + 5.5f, 0.f, 0.3f, 5.5f, + 1.5f, 0.f, 1.3f, 6.5f, + 8.6f, 0.f, 0.f, 0.4f, + 2.5f, 1.f, 0.3f, 4.5f, + 1.5f, 1.f, 1.3f, 1.5f, + 3.5f, 0.f, 1.3f, 2.5f, + 2.6f, 2.f, 3.f, 1.4f, + 4.5f, 1.f, 0.3f, 0.5f} + ); + + auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { + 0.9824562f, 0.f, 0.03822664f, 0.9824562f, + 0.67488194f, 0.f, 0.18924236f, 0.96960944f, + 0.99330735f, 0.f, 0.f, 0.37139067f, + 0.86567914f, 0.18702209f, 0.05610663f, 0.9520745f, + 0.6154575f, 0.34942827f, 0.45425674f, 0.6154575f, + 0.905509f, 0.f, 0.2824086f, 0.8361251f, + 0.57063663f, 0.41959068f, 0.629386f, 0.3504383f, + 0.9520745f, 0.21039814f, 0.06311944f, 0.3268602f } + ); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_4) { + + auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { + + 5.5f, 0.f, 0.3f, 5.5f, + 1.5f, 0.f, 1.3f, 6.5f, + 8.6f, 0.f, 0.f, 0.4f, + 2.5f, 1.f, 0.3f, 4.5f, + 1.5f, 1.f, 1.3f, 1.5f, + 3.5f, 0.f, 1.3f, 2.5f, + 2.6f, 2.f, 3.f, 1.4f, + 4.5f, 1.f, 0.3f, 0.5f} + ); + + auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}, { + 0.70082176f, 0.f, 0.03822664f, 0.70082176f, + 0.21835658f, 0.f, 0.18924236f, 0.9462118f, + 0.9922489f, 0.f, 0.f, 0.04615111f, + 0.46755522f, 0.18702209f, 0.05610663f, 
0.8415994f, + 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, + 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, + 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} + ); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests4, LrnTest_5) { + + auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, { + + 5.5f, 0.f, 0.3f, 5.5f, + 1.5f, 0.f, 1.3f, 6.5f, + 8.6f, 0.f, 0.f, 0.4f, + 2.5f, 1.f, 0.3f, 4.5f, + 1.5f, 1.f, 1.3f, 1.5f, + 3.5f, 0.f, 1.3f, 2.5f, + 2.6f, 2.f, 3.f, 1.4f, + 4.5f, 1.f, 0.3f, 0.5f} + ); + + auto eps = NDArrayFactory::create('c', {2, 2, 2, 4}, { + 0.70082176f, 0.f, 0.03822664f, 0.70082176f, + 0.21835658f, 0.f, 0.18924236f, 0.9462118f, + + 0.9922489f, 0.f, 0.f, 0.04615111f, + 0.46755522f, 0.18702209f, 0.05610663f, 0.8415994f, + + + 0.5241424f, 0.34942827f, 0.45425674f, 0.5241424f, + 0.76033086f, 0.f, 0.2824086f, 0.54309344f, + + 0.54546785f, 0.41959068f, 0.629386f, 0.29371348f, + 0.94679165f, 0.21039814f, 0.06311944f, 0.10519907f} + ); + + auto exp = NDArrayFactory::create('c', {2, 2, 2, 4}); + + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, {}, false); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); + // exp.printIndexedBuffer("LRN exp"); +// ASSERT_TRUE(exp.equalsTo(out)); + + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, tri_test1) { + + const int rows = 3; + const int cols = 5; + + auto expected = NDArrayFactory::create('c', 
{rows, cols}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f}); + + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols}); + auto output = results.at(0); + + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, tri_test2) { + + const int rows = 3; + const int cols = 5; + const int diag = 2; + + auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, tri_test3) { + + const int rows = 3; + const int cols = 5; + const int diag = -1; + + auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f}); + + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, tri_test4) { + + const int rows = 3; + const int cols = 5; + const int diag = -2; + + auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f}); + + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), 
results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, tri_test5) { + + const int rows = 5; + + auto expected = NDArrayFactory::create('c', {rows, rows}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, tri_test6) { + + const int rows = 3; + const int cols = 5; + const int diag = -20; + + auto expected = NDArrayFactory::create('c', {rows, cols}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, tri_test7) { + + const int rows = 3; + const int cols = 5; + const int diag = 20; + + auto expected = NDArrayFactory::create('c', {rows, cols}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + + sd::ops::tri op; + auto results = op.evaluate({}, {}, {rows, cols, diag}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test1) { + + auto input = 
NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 0, 5, 6, 0, 0, 9, 0, 0, 0}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test2) { + + auto input = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create('c', {4, 3}, {1, 2, 3,4, 5, 6,0, 8, 9,0, 0, 12}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-1}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test3) { + + auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2,3, 4,0, 6,7, 8,9,10,0,12}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-1}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test4) { + + auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2,0, 4,0, 0,7, 8,0, 10,0, 0}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + 
ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test5) { + + auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0, 2,0, 0,0, 0,0, 8,0, 0,0, 0}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {1}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test6) { + + auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0, 0,0, 0,0, 0,0, 0,0, 0,0, 0}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {10}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test7) { + + auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto expected = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-10}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test8) { + + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create('c', {6, 6}, 
{1, 2, 3, 4, 5, 6,0, 2, 3, 4, 5, 6,0, 0, 3, 4, 5, 6,0, 0, 0, 4, 5, 6,0, 0, 0, 0, 5, 6,0, 0, 0, 0, 0, 6}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test9) { + + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 0, 2, 3, 4, 5, 6, 0, 0, 3, 4, 5, 6}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-3}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test10) { + + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create('c', {6, 6}, {0, 0, 0, 4, 5, 6, 0, 0, 0, 0, 5, 6, 0, 0, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {3}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_test11) { + + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto expected = NDArrayFactory::create('c', {6, 6}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}); + + sd::ops::triu op; + auto results = op.evaluate({&input}, {}, {-58}); + auto output = 
results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_bp_test1) { + + auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto gradO = NDArrayFactory::create('c', {2, 3, 2}); + gradO = 0.5; + + auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0.,0.5,0.,0. ,0.,0. ,0.,0.5,0.,0. ,0.,0.}); + + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_bp_test2) { + + auto input = NDArrayFactory::create('c', {2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto gradO = NDArrayFactory::create('c', {2, 3, 2}); + gradO = 0.5; + + auto expected = NDArrayFactory::create('c', {2, 3, 2}, {0.5,0.5,0. ,0.5,0. ,0. ,0.5,0.5,0. ,0.5,0. ,0.}); + + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_bp_test3) { + + auto input = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + auto gradO = NDArrayFactory::create('c', {6,6}); + gradO = 0.5; + + auto expected = NDArrayFactory::create('c', {6,6}, {0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0.5, 0.5, 0.5, 0.5, 0.5, 0.5,0. , 0.5, 0.5, 0.5, 0.5, 0.5,0. , 0. , 0.5, 0.5, 0.5, 0.5,0. , 0. , 0. 
, 0.5, 0.5, 0.5}); + + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {-2}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests4, triu_bp_test4) { + + auto input = NDArrayFactory::create('c', {2,3}, {1, 2, 3, 4, 5, 6}); + auto gradO = NDArrayFactory::create('c', {2,3}); + gradO = 0.5; + + auto expected = NDArrayFactory::create('c', {2,3}, {0., 0., 0., 0., 0., 0.}); + + sd::ops::triu_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {10}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + + ASSERT_TRUE(expected.isSameShape(gradI)); + ASSERT_TRUE(expected.equalsTo(gradI)); + + +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests5.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests5.cpp new file mode 100644 index 000000000..0dbbe5490 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests5.cpp @@ -0,0 +1,3086 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#include "testlayers.h" +#include +#include +#include +#include + + +using namespace sd; +using namespace sd::graph; + +class DeclarableOpsTests5 : public testing::Test { +public: + + DeclarableOpsTests5() { + printf("\n"); + fflush(stdout); + } +}; + + +TEST_F(DeclarableOpsTests5, Test_PermuteEquality_1) { + auto x = NDArrayFactory::create('c', {1, 60}); + auto exp = NDArrayFactory::create('c', {3, 5, 4}, {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); + x.linspace(1); + x.reshapei('c', {3, 4, 5}); + + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0, 2, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests5, Test_PermuteEquality_0) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); + +// x.printShapeInfo("{0, 1, 2} shape"); +// x.printBuffer("{0, 1, 2} data"); + + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + 
ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests5, Test_PermuteEquality_2) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {4, 3, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); + +// x.printShapeInfo("{1, 0, 2} shape"); +// x.printBuffer("{1, 0, 2} data"); + + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {1, 0, 2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests5, Test_PermuteEquality_3) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {4, 5, 3}, {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); + +// x.printShapeInfo("{1, 2, 0} shape"); +// x.printBuffer("{1, 2, 0} data"); + + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {1, 2, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests5, Test_PermuteEquality_4) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {5, 3, 4}, {1.0, 6.0, 11.0, 16.0, 21.0, 
26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); + x.reshapei('c', {3, 4, 5}); + +// x.printShapeInfo("{2, 0, 1} shape"); +// x.printBuffer("{2, 0, 1} data"); + + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {2, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests5, Test_PermuteEquality_5) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {5, 4, 3}, {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, 2.0, 22.0, 42.0, 7.0, 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, 53.0, 18.0, 38.0, 58.0, 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, 5.0, 25.0, 45.0, 10.0, 30.0, 50.0, 15.0, 35.0, 55.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); + +// x.printShapeInfo("{2, 1, 0} shape"); +// x.printBuffer("{2, 1, 0} data"); + + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {2, 1, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests5, Test_TTS_bp_1) { + auto x = NDArrayFactory::create('c', {2, 1, 3}); + auto eps = NDArrayFactory::create('c', {2, 4, 3}); + + auto exp = NDArrayFactory::create('c', {2, 1, 3}, {22.f, 26.f, 30.f, 70.f, 74.f, 78.f}); + + eps.linspace(1.f); + + sd::ops::tile_to_shape_bp op; + auto result = op.evaluate({&x, &eps}, {}, {2, 4, 3}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printShapeInfo("RES shape"); + // 
x.printShapeInfo("EXP shape"); + // z->printIndexedBuffer("RES output"); + ASSERT_TRUE(x.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests5, Test_Rdiv_bp_1) { + auto x = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto eps = NDArrayFactory::create('c', {3, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + + + sd::ops::reversedivide op_ff; + auto result_ff = op_ff.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result_ff.status()); + + auto z_ff = result_ff.at(0); + ASSERT_TRUE(eps.isSameShape(z_ff)); + + sd::ops::reversedivide_bp op_bp; + auto result_bp = op_bp.evaluate({&x, &y, &eps}, {}, {}); + ASSERT_EQ(Status::OK(), result_bp.status()); + + auto z_bp = result_bp.at(0); + ASSERT_TRUE(x.isSameShape(z_bp)); +} + + +TEST_F(DeclarableOpsTests5, Test_Boolean_diff_1) { + auto x = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto y = NDArrayFactory::create(2.0f); + + sd::ops::less op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(result.at(0)->t(0), true); + +} + +TEST_F(DeclarableOpsTests5, Test_SetSeed_1) { + auto x = NDArrayFactory::create('c', {1, 1}, {120}); + auto y = NDArrayFactory::create(5); + + sd::ops::set_seed op; + auto result = op.evaluate({&x, &y}, {}, {120, 5}); + + ASSERT_EQ(Status::OK(), result.status()); +// result->at(0)->printIndexedBuffer("RES SEED"); + + sd::ops::get_seed getOp; + auto getRes = getOp.evaluate({}); + ASSERT_EQ(Status::OK(), getRes.status()); +// getres.at(0)->printIndexedBuffer("Output RES GET SEED"); +// ASSERT_EQ(result.at(0)->t(0), true); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, scatterMul_test1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); + auto 
exp = NDArrayFactory::create('c', {2, 2}, {10.f, 2.f, 3.f, 4.f}); + + sd::ops::scatter_mul op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, scatterDiv_test1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.10f, 2.f, 3.f, 4.f}); + + sd::ops::scatter_div op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Scatter Div"); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, scatterSub_test1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + NDArray idc('c', {1}, std::vector({0LL}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10.f, 1.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {-9.f, 1.f, 3.f, 4.f}); + + sd::ops::scatter_sub op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Scatter Sub"); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, hardsigmoid_test1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.7f, 0.9f, 1.f, 1.f}); + + sd::ops::hardsigmoid op; + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = 
result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, hardsigmoid_test2) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto eps = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0.2f, 0.4f, 0.f, 0.f}); + + sd::ops::hardsigmoid_bp op; + auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, hardtanh_test1) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3, 3}, {-1, -1, -1, -1, 0, 1, 1, 1, 1}); + + sd::ops::hardtanh op; + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Hardtanh 2x2"); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, hardtanh_test2) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto eps = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 4, 5, 6, 0, 0, 0}); + + sd::ops::hardtanh_bp op; + auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Hardtanh_bp 2x2"); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, histogram_test1) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {3, 3, 3}); + + 
sd::ops::histogram op; + auto result = op.evaluate({&matrix}, {}, {3}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Histogram3"); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, histogram_test2) { + auto matrix = NDArrayFactory::create('c', {3}, {1, 2, 1}); + auto exp = NDArrayFactory::create('c', {4}, {2, 0, 0, 1}); + + sd::ops::histogram op; + auto result = op.evaluate({&matrix}, {}, {4}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Identity_test1) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f}); +// auto exp = NDArrayFactory::create('c', {3, 3}, {3, 3, 3}); + + sd::ops::identity op; + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(matrix.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Identity_test2) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {-4, -3, -2, -1, 0, 1, 2, 3, 4}); + auto eps = NDArrayFactory::create('c', {3, 3}, {1,2,3,4,5,6,7,8,9}); +// auto exp = NDArrayFactory::create('c', {3,3}); + sd::ops::identity_bp op; + auto result = op.evaluate({&matrix, &eps}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z->equalsTo(eps)); + + +} +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Log1p_test1) { + auto matrix = NDArrayFactory::create('c', {3, 3}, {4, 3, 2, 1, 0, 1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {3,3}, {5,4,3,2,1,2,3,4,5}); + // auto eps = NDArrayFactory::create('c', {3, 3}, 
{1,2,3,4,5,6,7,8,9}); +// auto exp = NDArrayFactory::create('c', {3,3}); + sd::ops::Log1p op; + y.applyTransform(sd::transform::Log, y); + auto result = op.evaluate({&matrix}, {}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(z->equalsTo(y)); + + +} + +TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_1) { + + auto x = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_2) { + auto x = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 1, 1, 1}, {1, 2, 3, 4}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_3) { + + auto x = NDArrayFactory::create('c', {2, 2, 4, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0}); + auto exp = NDArrayFactory::create('c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11,0, 2, 4, 0, 10, 12,0, 5, 7, 0, 13, 15,0, 6, 8, 0, 14, 16}); + + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Test_SpaceToBatch_4) { + + const int blockSize = 2; + NDArray x('c', {3, 3*blockSize - 1 - 2, 4*blockSize - 2 - 3, 2}, {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, 37, 38, 179, 180, 251, 252, 181, 182, 195, 196, 267, 268, 197, 198, 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, sd::DataType::FLOAT32); + NDArray paddings = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); + + NDArray exp('c', {3*blockSize*blockSize, 3, 4, 2}, {0,0, 0,0, 0,0, 0,0, 0,0, 11,12, 13,14, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, + 0,0, 0,0, 0,0, 35,36, 37,38, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 59,60, 61,62, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, + 0,0, 0,0, 0,0, 0,0, 83,84, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 107, 108, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, + 0,0, 0,0, 0,0, 0,0, 0,0, 131, 132, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 147, 148, 149, 150, 0,0, 0,0, 155, 156, 157, 158, + 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 171, 172, 173, 174, 0,0, 0,0, 179, 180, 181, 182, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 195, 196, + 197, 198, 0,0, 0,0, 203, 204, 205, 206, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 219, 220, 0,0, 0,0, 0,0, 227, 228, 0,0, 0,0, 0,0, + 0,0, 0,0, 0,0, 0,0, 243, 244, 0,0, 0,0, 0,0, 251, 252, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0, 267, 268, 0,0, 0,0, 0,0, 275, + 276, 0,0, 0,0, 0,0, 0,0, 0,0, 0,0}, sd::DataType::FLOAT32); + + sd::ops::space_to_batch op; + auto result = op.evaluate({&x, &paddings}, {}, {blockSize}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests5, Test_BatchToSpace_1) { + auto x = NDArrayFactory::create('c', {4, 1, 1, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {1, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 
7, 8, 9, 10, 11, 12}); + auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests5, Test_BatchToSpace_2) { + auto x = NDArrayFactory::create('c', {4, 1, 1, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 2, 2, 1}, {1, 2, 3, 4}); + auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(DeclarableOpsTests5, Test_BatchToSpace_3) { + auto x = NDArrayFactory::create('c', {8, 1, 3, 1}, {0, 1, 3, 0, 9, 11, + 0, 2, 4, 0, 10, 12, + 0, 5, 7, 0, 13, 15, + 0, 6, 8, 0, 14, 16}); + auto exp = NDArrayFactory::create('c', {2, 2, 4, 1}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto crops = NDArrayFactory::create('c', {2, 2}, {0, 0, 2, 0}); + + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Test_BatchToSpace_4) { + + const int blockSize = 2; + NDArray x('c', {3*blockSize*blockSize, 3, 4, 2}, sd::DataType::FLOAT32); + x.linspace(1, 1); + NDArray crops = NDArrayFactory::create('c', {2, 2}, {1, 2, 2, 3}); + + NDArray exp('c', {3, 3*blockSize - 1 - 2, 4*blockSize - 2 - 3, 2}, {147, 148, 219, 220, 149, 150, 11, 12, 83, 84, 13, 14, 155, 156, 227, 228, 157, 158, 171, 172, 243, 244, 173, 174, 35, 36, 107, 108, 37, 38, 179, 180, 251, 252, 
181, 182, 195, 196, 267, 268, 197, 198, 59, 60, 131, 132, 61, 62, 203, 204, 275, 276, 205, 206}, sd::DataType::FLOAT32); + + sd::ops::batch_to_space op; + auto result = op.evaluate({&x, &crops}, {}, {blockSize}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, eye_test1) { + + auto expected = NDArrayFactory::create('c', {3, 3}, {1.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 1.f}); + + sd::ops::eye op; + auto results = op.evaluate({}, {}, {-99, 3}); + auto output = results.at(0); + // output->printIndexedBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, eye_test2) { + + auto expected = NDArrayFactory::create('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); + + sd::ops::eye op; + auto results = op.evaluate({}, {}, {-99, 3, 4}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, eye_test3) { + + auto expected = NDArrayFactory::create('c', {2, 3, 4}, {1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0}); + + sd::ops::eye op; + auto results = op.evaluate({}, {9 /*int*/}, {-99, 3, 4, 2}); + auto output = results.at(0); + // output->printIndexedBuffer("Output eye"); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, 
eye_test4) { + + auto expected = NDArrayFactory::create('c', {2, 2, 3, 4}, {1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.}); + + sd::ops::eye op; + auto results = op.evaluate({}, {6/*double*/}, {-99, 3, 4, 2, 2}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, eye_test5) { + + sd::ops::eye op; + auto result = op.evaluate({},{},{3, 2}); + + auto z = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test1) { + + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {2,2,1}, {3,2,3,2}); + + auto expected = NDArrayFactory::create('c', {2,2,3,2}, {19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 13, 14, 15, 16, 17, 18}); + + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test2) { + + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {2,2,2}, {3,2,1,2, 0,1,0,1}); + + auto expected = NDArrayFactory::create('c', {2,2,2}, {23, 24, 11, 12, 3, 4, 3, 4}); + + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}, {true}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), 
results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test3) { + + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {3}, {3,2,1}); + auto expected = NDArrayFactory::create(24.); + + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test4) { + + auto input = NDArrayFactory::create('c', {4, 3, 2}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {2,3}, {3,2,1,0,2,1}); + auto expected = NDArrayFactory::create('c',{2}, {24., 6}); + + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test5) { + + auto input = NDArrayFactory::create('c', {4}, {1,2,3,4}); + auto indices = NDArrayFactory::create('c', {5,1}, {3,2,0,1,1}); + auto expected = NDArrayFactory::create('c',{5}, {4.,3,1,2,2}); + + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test6) { + + auto input = NDArrayFactory::create('c', 
{4}, {1,2,3,4}); + std::vector shape = {1}; + auto indices = NDArrayFactory::create('c', shape, {2}); + auto expected = NDArrayFactory::create(3.); + + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test7) { + + auto input = NDArrayFactory::create('c', {4, 4}); + input.linspace(1); + auto indices = NDArrayFactory::create('c', {3,3,2}, {0,2,1, 0,1,0, 1,3,1, 0,2,1, 0,1,0, 1,3,1}); + auto expected = NDArrayFactory::create('c', {3,3}, {3,5,5,8,5,10,2,2,14}); + + sd::ops::gather_nd op; + auto results = op.evaluate({&input, &indices}, {}, {}, {true}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test8) { + auto x = NDArrayFactory::create('c', {2, 2}, {1., 2., 3., 4.}); + auto y = NDArrayFactory::create('c', {2, 2}, {0, 0, 1, 1}); + auto e = NDArrayFactory::create('c', {2}, {1., 4.}); + + sd::ops::gather_nd op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); + + +} + +TEST_F(DeclarableOpsTests5, gatherNd_test9) { + auto x = NDArrayFactory::create('c', {2, 4, 2, 2}); + auto indices = NDArrayFactory::create('c', {3, 3}, {0,2,1, 0,1,0, 1,3,1}); + auto exp = NDArrayFactory::create('c', {3,2}, {11.f, 12.f, 5.f, 6.f, 31.f, 32.f}); + x.linspace(1); + + sd::ops::gather_nd op; + auto result = op.evaluate({&x, &indices}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + //z->printIndexedBuffer(); + 
//z->printShapeInfo("z shape"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test10) { + + auto input = NDArrayFactory::create('c', {4, 3, 2}); + auto indices = NDArrayFactory::create('c', {2,2,2}, {30,20,1,2, 0,10,0,1}); + + auto output = NDArrayFactory::create('c', {2,2,2}); + + sd::ops::gather_nd op; + + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, gatherNd_test11) { + + auto input = NDArrayFactory::create('c', {4, 4}); + auto indices = NDArrayFactory::create('c', {3,3,2}, {0,2,1, 0,10,0, 1,30,1, 0,20,1, 0,1,0, 1,30,1}); + auto output = NDArrayFactory::create('c', {3,3}); + + sd::ops::gather_nd op; + + ASSERT_ANY_THROW(op.execute({&input, &indices}, {&output}, {}, {}, {true})); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test1) { + + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {4}, {4,4,4,4}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {4, 3, 2, 1, 5, 9, 8, 7, 6, 10, 14, 13, 12, 11, 15, 19, 18, 17, 16, 20, 24, 23, 22, 21, 25, 29, 28, 27, 26, 30, 34, 33, 32, 31, 35, 39, 38, 37, 36, 40, 44, 43, 42, 41, 45, 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test2) { + + auto input = NDArrayFactory::create('c', {3, 4, 5}); + 
input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {4}, {0,1,2,3}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 11, 13, 14, 15, 18, 17, 16, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 32, 31, 33, 34, 35, 38, 37, 36, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 52, 51, 53, 54, 55, 58, 57, 56, 59, 60}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 1}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test3) { + + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {3}, {2,3,4}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {2, 1, 3, 4, 5, 7, 6, 8, 9, 10, 12, 11, 13, 14, 15, 17, 16, 18, 19, 20, 23, 22, 21, 24, 25, 28, 27, 26, 29, 30, 33, 32, 31, 34, 35, 38, 37, 36, 39, 40, 44, 43, 42, 41, 45, 49, 48, 47, 46, 50, 54, 53, 52, 51, 55, 59, 58, 57, 56, 60}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test4) { + + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 22, 3, 24, 45, 6, 27, 8, 29, 50, 11, 32, 13, 34, 55, 16, 37, 18, 39, 60, 21, 2, 23, 4, 25, 26, 7, 28, 9, 30, 31, 12, 33, 14, 35, 36, 17, 38, 19, 40, 41, 42, 43, 44, 5, 46, 47, 48, 49, 10, 51, 52, 53, 54, 15, 56, 57, 
58, 59, 20}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 2}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test5) { + + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {5}, {1, 2, 4, 2, 3}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 7, 18, 9, 15, 6, 2, 13, 4, 10, 11, 12, 8, 14, 5, 16, 17, 3, 19, 20, 21, 27, 38, 29, 35, 26, 22, 33, 24, 30, 31, 32, 28, 34, 25, 36, 37, 23, 39, 40, 41, 47, 58, 49, 55, 46, 42, 53, 44, 50, 51, 52, 48, 54, 45, 56, 57, 43, 59, 60}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test6) { + + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto seqLengths = NDArrayFactory::create('c', {4}, {1, 2, 3, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1, 2, 3, 4, 5, 26, 27, 28, 29, 30, 51, 52, 53, 54, 55, 36, 37, 38, 39, 40, 21, 22, 23, 24, 25, 6, 7, 8, 9, 10, 31, 32, 33, 34, 35, 16, 17, 18, 19, 20, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 11, 12, 13, 14, 15, 56, 57, 58, 59, 60}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests5, reverse_sequense_test7) { + + auto input = NDArrayFactory::create('c', {1, 5}); + input.linspace(1); + std::vector data = {3}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {1, 5}, {3, 2, 1, 4, 5}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test8) { + + auto input = NDArrayFactory::create('c', {1, 5}); + input.linspace(1); + std::vector data = {1,0,1,0,1}; + auto seqLengths = NDArrayFactory::create('c', {5}, data); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test9) { + + auto input = NDArrayFactory::create('c', {5, 1}); + input.linspace(1); + std::vector data = {1,0,1,0,1}; + auto seqLengths = NDArrayFactory::create('c', {5}, data); + auto exp = NDArrayFactory::create('c', {5, 1}, {1, 2, 3, 4, 5}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 0}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test10) { + + auto input = NDArrayFactory::create('c', {5, 1}); + 
input.linspace(1); + std::vector data = {3}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {5, 1}, {3, 2, 1, 4, 5}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {0, 1}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test11) { + + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); + input.linspace(1); + std::vector data = {1, 0, 1, 0, 1}; + auto seqLengths = NDArrayFactory::create('c', {5}, data); + auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {1, 2}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test12) { + + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); + input.linspace(1); + std::vector data = {3}; + auto seqLengths = NDArrayFactory::create('c', {1}, data); + auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {3, 2, 1, 4, 5}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {2, 0}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test13) { + + auto input = NDArrayFactory::create('c', {1, 1, 5, 1}); + input.linspace(1); + std::vector data = {1}; + auto seqLengths = NDArrayFactory::create('c', {1}, 
data); + auto exp = NDArrayFactory::create('c', {1, 1, 5, 1}, {1, 2, 3, 4, 5}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &seqLengths}, {}, {3, 0}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, reverse_sequense_test14) { + auto input = NDArrayFactory::create('c', {8, 8, 3, 2}, {0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, 0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 0.93536442, 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, 0.86160433, 0.79580411, 
0.47710411, 0.09318029, 0.31369071, 0.64122249, 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, 0.47934951, 0.49919534, 0.34369898, 0.78211256, 0.13908708, 0.95754117, 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, 
0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); + auto lengths = NDArrayFactory::create('c', {8}, {7, 2, 3, 5, 2, 1, 6, 4}); + auto e = NDArrayFactory::create('c', {8, 8, 3, 2}, {0.54193264, 0.05176904, 0.82555761, 0.71106697, 0.04416722, 0.07653656, 0.06478678, 0.68985848, 0.55216783, 0.55382648, 0.34652863, 0.17261296, 0.61523204, 0.64298760, 0.26848351, 0.75015615, 0.28683049, 0.70937606, 0.38700677, 0.68832738, 0.37292716, 0.94616004, 0.77735792, 0.60803430, 0.54555996, 0.23407607, 0.11372584, 0.49965927, 0.15210842, 0.53268608, 0.67197708, 0.80659380, 0.98274191, 0.63566073, 0.21592326, 0.54902743, 0.09753360, 0.76124972, 0.24693797, 0.13813169, 0.33144656, 0.08299957, 0.01034390, 0.99430482, 0.59944390, 0.17973880, 0.36437840, 0.86383673, 0.93630291, 0.67277404, 0.93899264, 0.52422773, 0.44892176, 0.03127759, 0.45025550, 0.97136977, 0.13565978, 0.71567448, 0.92094825, 
0.93536442, 0.85910449, 0.18252879, 0.72830945, 0.96736828, 0.89831575, 0.83437150, 0.59050780, 0.36145925, 0.16483070, 0.44021176, 0.76018652, 0.44227383, 0.13052339, 0.18204235, 0.99743733, 0.26885190, 0.87726522, 0.16396056, 0.94943412, 0.40016700, 0.65267938, 0.71073267, 0.40094733, 0.91182634, 0.05391789, 0.49520416, 0.24963864, 0.34847086, 0.74088617, 0.36115701, 0.63074210, 0.97423085, 0.42216846, 0.06326975, 0.07858702, 0.20586622, 0.34755773, 0.63166554, 0.18849320, 0.34828456, 0.98477707, 0.75163124, 0.33309570, 0.67563176, 0.98343578, 0.95919930, 0.66994391, 0.89296165, 0.28752144, 0.38146961, 0.83518735, 0.08207577, 0.82083487, 0.81665728, 0.83306004, 0.14203056, 0.01497920, 0.85727447, 0.71194544, 0.85654019, 0.86160433, 0.79580411, 0.47710411, 0.09318029, 0.31369071, 0.64122249, 0.58399725, 0.26706597, 0.05655339, 0.91025211, 0.30330468, 0.33142930, 0.05668627, 0.02936449, 0.12613087, 0.09960114, 0.16218074, 0.15088139, 0.31239040, 0.55980062, 0.34804391, 0.34941538, 0.61370555, 0.07022964, 0.27274571, 0.83306066, 0.75830824, 0.25963478, 0.87137718, 0.24418835, 0.59371493, 0.74479056, 0.84699625, 0.51210368, 0.12489571, 0.23371067, 0.18361641, 0.48636240, 0.06052657, 0.04241913, 0.66710351, 0.07007925, 0.60553664, 0.07536713, 0.55971796, 0.38764845, 0.20737843, 0.37989120, 0.59757058, 0.31189846, 0.25215345, 0.52546591, 0.55744218, 0.59485650, 0.05032742, 0.52076188, 0.47762345, 0.89829370, 0.34417708, 0.84705151, 0.08203183, 0.10632956, 0.78431292, 0.86441722, 0.36487598, 0.09833603, 0.85863594, 0.11010505, 0.11659283, 0.42500288, 0.02747301, 0.12359903, 0.19736489, 0.44461885, 0.33341706, 0.22519571, 0.31528710, 0.14802902, 0.01753431, 0.41160932, 0.47245979, 0.08268172, 0.21580773, 0.75770279, 0.64171939, 0.52643769, 0.19261234, 0.98032835, 0.15401656, 0.85274458, 0.66408502, 0.23212704, 0.74630026, 0.05713613, 0.49025892, 0.48418810, 0.59541513, 0.09243053, 0.93919152, 0.95357019, 0.52377729, 0.65963871, 0.47934951, 0.49919534, 0.34369898, 
0.78211256, 0.13908708, 0.95754117, 0.84107746, 0.09126213, 0.42979124, 0.10295325, 0.34631257, 0.69448345, 0.41720536, 0.15282440, 0.74329854, 0.45775009, 0.12786280, 0.39830299, 0.20386769, 0.59703523, 0.94077086, 0.42255597, 0.80453309, 0.79757204, 0.28653229, 0.60175909, 0.55859623, 0.34318230, 0.63002770, 0.36533324, 0.89689906, 0.73236186, 0.61491989, 0.83787947, 0.67939463, 0.72016694, 0.77499849, 0.72428343, 0.34571059, 0.23143007, 0.20099338, 0.85583142, 0.73174191, 0.54284092, 0.20264181, 0.53037061, 0.30493131, 0.82279766, 0.58542432, 0.72632070, 0.18394258, 0.00608118, 0.23808232, 0.17007573, 0.75245459, 0.84990616, 0.38827634, 0.33809538, 0.01080317, 0.27250145, 0.81769542, 0.15323253, 0.71668395, 0.99427044, 0.11355576, 0.50511923, 0.22952055, 0.78271870, 0.12833592, 0.88639055, 0.76398188, 0.49533508, 0.47939640, 0.73564612, 0.41465671, 0.10995635, 0.20271728, 0.00521771, 0.67265260, 0.11917707, 0.76574855, 0.43842117, 0.28530411, 0.79648090, 0.79433656, 0.12074559, 0.02325163, 0.10117917, 0.83559239, 0.67213900, 0.25247084, 0.47968157, 0.88649124, 0.33588961, 0.92338319, 0.18808573, 0.60248266, 0.36610154, 0.99123140, 0.10519719, 0.18754650, 0.43232584, 0.85447872, 0.15937568, 0.92947480, 0.62705964, 0.85960084, 0.13435660, 0.81845809, 0.60715133, 0.83030708, 0.83071910, 0.38883408, 0.92033237, 0.59820890, 0.75527947, 0.67683355, 0.21847023, 0.29395619, 0.50477953, 0.98977921, 0.96225332, 0.90143562, 0.19559914, 0.08978307, 0.09687492, 0.07381865, 0.22801110, 0.26669388, 0.99691302, 0.12113623, 0.34373057, 0.46066239, 0.48806761, 0.50688779, 0.00654483, 0.32076493, 0.42367646, 0.07112842, 0.54090558, 0.68230725, 0.49713828, 0.41958965, 0.68013847, 0.47691765, 0.63269259, 0.94304095, 0.54587271, 0.72447569, 0.28913523, 0.75766936, 0.52965692, 0.96854824, 0.15589071, 0.84128672, 0.16337522, 0.05771034, 0.21556356, 0.12094140, 0.29721207, 0.00811008, 0.66184926}); + + sd::ops::reverse_sequence op; + auto results = op.evaluate({&input, &lengths}, {}, 
{1, 0}); + ASSERT_EQ(Status::OK(), results.status()); + + auto z = results.at(0); + + ASSERT_EQ(e, *z); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Test_TopK_0) { + auto x = NDArrayFactory::create('c', {2, 6}, {1.0, 1.0, 1.0, 1.0, 11.0, 3.0, 1.0, 1.0, 1.0, 14.0, 5.0, 6.0}); + auto expV = NDArrayFactory::create('c', {2, 1}, {11.0, 14.0}); + auto expI = NDArrayFactory::create('c', {2, 1}, {4, 3}); + + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto i = result.at(1); +/* + v->printShapeInfo("topK_0: shape v"); + expV.printShapeInfo("topK_0: shape expV"); + + i->printShapeInfo("topK_0: shape I"); + expI.printShapeInfo("topK_0: shape expI"); + + v->printIndexedBuffer("topK_0: v"); + expV.printIndexedBuffer("topK_0: expV"); + i->printIndexedBuffer("topK_0: i"); + expI.printIndexedBuffer("topK_0: expI"); +*/ + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + // repeat res again + for (int cases = 0; cases < 100; ++cases) { + op.execute({&x}, std::vector{v, i}, {}, {1, 0}, {}); // without sorting + } + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Test_TopK_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {1.0f, 11.0f, 3.0f, 14.0f, 5.0f, 6.0f}); + auto expV = NDArrayFactory::create('c', {2, 1}, {11.0f, 14.0f}); + auto expI = NDArrayFactory::create('c', {2, 1}, {1, 0}); + + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {1, 0}); // without sorting + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto i = result.at(1); + +// v->printShapeInfo("topK_1: shape v"); +// expV.printShapeInfo("topK_1: shape expV"); + +// i->printShapeInfo("topK_1: 
shape I"); +// expI.printShapeInfo("topK_1: shape expI"); + +// v->printIndexedBuffer("topK_1: v"); +// expV.printIndexedBuffer("topK_1: expV"); +// i->printIndexedBuffer("topK_1: i"); +// expI.printIndexedBuffer("topK_1: expI"); + + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + // repeat res again + for (int cases = 0; cases < 100; ++cases) { + op.execute({&x}, std::vector{v, i}, {}, {1, 0}, {}); // without sorting + } + +} + +/////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Test_TopK_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, + 6.0, 9.0, 3.5, 7.0, + 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, + 11.0, 13.0, 14.0, 5.0, + 16.0, 9.0, 13.5, 7.0 + } + ); +// <<<14.>,<9.>>, <<21.>,<9.>>, <<14.>,<16.>>> + auto expV = NDArrayFactory::create('c', {2, 3, 1}, {14.0f, 9.0f, + 21.0f, + 9.0f, 14.0f, + 16.0f + } + ); + + auto expI = NDArrayFactory::create('c', {2, 3, 1 }, {2, 1, 0, 1, 2, 0}); + + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto i = result.at(1); + +// v->printShapeInfo("shape v"); +// expV.printShapeInfo("shape expV"); + +// i->printShapeInfo("shape I"); +// expI.printShapeInfo("shape expI"); + +// v->printIndexedBuffer("v"); +// expV.printIndexedBuffer("expV"); +// i->printIndexedBuffer("i"); +// expI.printIndexedBuffer("expI"); + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + + +} + +TEST_F(DeclarableOpsTests5, Test_TopK_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, + 6.0, 9.0, 3.5, 7.0, + 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, + 11.0, 13.0, 14.0, 5.0, + 16.0, 9.0, 13.5, 7.0 + } + ); + + auto expV = NDArrayFactory::create('c', {2, 3, 
2}, {14.0f, 11.0f, + 9.0f, 7.0f, + 21.0f, 15.0f, + 9.0f, 7.0f, + 14.0f, 13.0f, + 16.0f, 13.5f + } + ); + + auto expI = NDArrayFactory::create('c', {2, 3, 2 }, {2, 0, 1, 3, 0, 3, 1, 3, 2, 1, 0, 2}); + + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto i = result.at(1); + +// v->printShapeInfo("shape v"); +// expV.printShapeInfo("shape expV"); + +// i->printShapeInfo("shape I"); +// expI.printShapeInfo("shape expI"); + +// v->printIndexedBuffer("v"); +// expV.printIndexedBuffer("expV"); +// i->printIndexedBuffer("i"); +// expI.printIndexedBuffer("expI"); + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + + +} + +TEST_F(DeclarableOpsTests5, Test_TopK_3_unsorted) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, + 6.0, 9.0, 3.5, 7.0, + 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, + 11.0, 13.0, 14.0, 5.0, + 16.0, 9.0, 13.5, 7.0 + } + ); + + auto expV = NDArrayFactory::create('c', {2, 3, 2}, {11.0f, 14.0f, + 9.0f, 7.0f, + 21.0f, 15.0f, + 9.0f, 7.0f, + 13.0f, 14.0f, + 16.0f, 13.5f + } + ); + + auto expI = NDArrayFactory::create('c', {2, 3, 2 }, {0, 2, 1, 3, 0, 3, 1, 3, 1, 2, 0, 2}); + + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto i = result.at(1); + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Test_TopK_4) { + auto x = NDArrayFactory::create('c', {2, 3}, {1.0f, 11.0f, 3.0f, 14.0f, 5.0f, 6.0f}); + auto expV = NDArrayFactory::create('c', {2, 2}, {11.0f, 3.0f, 14.0f, 6.0f}); + auto expI = 
NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 2}); + + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto i = result.at(1); + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Test_TopK_5) { + auto x = NDArrayFactory::create('f', {2, 3}, {1.1, 5.2, 3.1, 14.2, 11.1, 6.2}); + auto expV = NDArrayFactory::create('f', {2, 2}, {11.1, 14.2, 3.1, 6.2}); + auto expI = NDArrayFactory::create('f', {2, 2}, {2, 1, 1, 2}); + + sd::ops::top_k op; + auto result = op.evaluate({&x}, {}, {2, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto i = result.at(1); + + ASSERT_TRUE(expV.isSameShape(v)); + ASSERT_TRUE(expV.equalsTo(v)); + + ASSERT_TRUE(expI.isSameShape(i)); + ASSERT_TRUE(expI.equalsTo(i)); + + +} + +/////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, Test_Moments_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, + 6.0, 9.0, 3.5, 7.0, + 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, + 11.0, 13.0, 14.0, 5.0, + 16.0, 9.0, 13.5, 7.0} + ); + + auto y = NDArrayFactory::create('c', {3}, {0, 1, 2}); + //auto expV('f', {6}, {1, 0, 0, 0, 0, 0 }); + + float expMean = 9.395833f; + float expDeviation = 22.4579f; +//Mean 9.395833 +//Deviance 22.4579 + + float inf = 1.e-5f; + + sd::ops::moments op; + auto result = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto d = result.at(1); + +// v->printIndexedBuffer("Result is "); +// d->printIndexedBuffer("Result is "); + + ASSERT_TRUE(v->isScalar()); + ASSERT_NEAR(expMean, v->e(0), inf); + 
ASSERT_NEAR(expDeviation, d->e(0), inf); + + +} + +TEST_F(DeclarableOpsTests5, Test_Moments_2) { + NDArray x('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, + 6.0, 9.0, 3.5, 7.0, + 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, + 11.0, 13.0, 14.0, 5.0, + 16.0, 9.0, 13.5, 7.0} + ); + + NDArray expV('c', {4}, {11.833333, 7.6666665, 10.416667, 7.6666665}); + NDArray expD('c', {4}, {28.472221, 12.888889, 23.951387, 11.555554}); + + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto d = result.at(1); + + ASSERT_TRUE(v->isVector()); + ASSERT_TRUE(d->isVector()); + + ASSERT_TRUE(v->equalsTo(&expV)); + ASSERT_TRUE(d->equalsTo(&expD)); + + +} + +TEST_F(DeclarableOpsTests5, Test_Moments_3) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, {11.0, 3.0, 14.0, 5.0, + 6.0, 9.0, 3.5, 7.0, + 21.0, 3.0, 14.0, 15.0, + 6.0, 9.0, 3.5, 7.0, + 11.0, 13.0, 14.0, 5.0, + 16.0, 9.0, 13.5, 7.0} + ); + + auto expV = NDArrayFactory::create('c', {3, 4}, { 8.5f, 6.f , 8.75f, 6.f, + 8.5f, 11.f, 8.75f, 6.f, + 18.5f, 6.f, 13.75f, 11.f}); + auto expD = NDArrayFactory::create('c', {3, 4}, { 6.25f, 9.f, 27.5625f, 1.f, + 6.25f, 4.f, 27.5625f, 1.f, + 6.25f, 9.f, 0.0625f, 16.f}); + + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto d = result.at(1); + + ASSERT_TRUE(v->isMatrix()); + ASSERT_TRUE(d->isMatrix()); + + ASSERT_TRUE(v->equalsTo(&expV)); + ASSERT_TRUE(d->equalsTo(&expD)); + + +} + +TEST_F(DeclarableOpsTests5, Test_Moments_4) { + + auto x = NDArrayFactory::create('f', {2, 3, 4}, {11.0f, 6.0f, 6.0f, 11.0f, 21.0f, 16.0f, 3.0f, 9.0f, 9.0f, 13.0f, 3.0f, 9.0f, + 14.0f, 3.5f, 3.5f, 14.0f, 14.0f, 13.5f, 5.0f, 7.0f, 7.0f, 5.0f, 15.0f, 7.0f}); + + + auto expV = NDArrayFactory::create('c', {3, 4}, { 8.5f, 6.f , 8.75f, 6.f, 8.5f, 11.f, 8.75f, 6.f, 18.5f, 6.f, 
13.75f, 11.f}); + auto expD = NDArrayFactory::create('c', {3, 4}, { 6.25f, 9.f, 27.5625f, 1.f, 6.25f, 4.f, 27.5625f, 1.f, 6.25f, 9.f, 0.0625f, 16.f}); + + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(2, result.size()); + + auto v = result.at(0); + auto d = result.at(1); + + ASSERT_TRUE(v->isMatrix()); + ASSERT_TRUE(d->isMatrix()); + + // v->printIndexedBuffer("v"); + // expV.printIndexedBuffer("expV"); + + // d->printIndexedBuffer("d"); + // expD.printIndexedBuffer("expD"); + + ASSERT_TRUE(v->equalsTo(&expV)); + ASSERT_TRUE(d->equalsTo(&expD)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, trace_test1) { + + auto input = NDArrayFactory::create('c', {3, 4, 5}); + input.linspace(1); + auto exp = NDArrayFactory::create('c', {3}, {40, 120, 200}); + NDArray matrix('c', {3, 3}, {1., 2., 3., 4., 5., 6., 7., 8., 9.}); + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + double traceM = matrix.getTrace(); + // nd4j_printf("Trace for matrix is %f\n", traceM); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // exp.printIndexedBuffer("EXP TRACE"); + // output->printIndexedBuffer("OUT TRACE"); + ASSERT_TRUE(exp.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, trace_test2) { + + auto input = NDArrayFactory::create('c', {4, 5}); + input.linspace(1); + auto exp = NDArrayFactory::create(40.); + + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, trace_test3) { + + auto input = NDArrayFactory::create('c', 
{1, 5}); + input.linspace(1); + auto exp = NDArrayFactory::create(1.); + + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, trace_test4) { + + auto input = NDArrayFactory::create('c', {5, 1}); + input.linspace(1); + auto exp = NDArrayFactory::create(1.); + + sd::ops::trace op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, trace_test5) { + + auto input = NDArrayFactory::create('c', {3, 4, 5, 6}); + input.linspace(1); + auto exp = NDArrayFactory::create('c', {3, 4}, {75, 225, 375, 525, 675, 825, 975, 1125, 1275, 1425, 1575, 1725}); + + sd::ops::trace op; + auto results = op.evaluate({&input}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, random_shuffle_test1) { + + auto input = NDArrayFactory::create('c', {2, 2, 2}); + input.linspace(1); + NDArray exp1 = input.dup(); + NDArray exp2('c',{2,2,2}, {5,6,7,8, 1,2,3,4}, sd::DataType::DOUBLE); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, random_shuffle_test2) { + + auto input = NDArrayFactory::create('c', {1, 
3, 2}); + input.linspace(1); + NDArray exp1 = input.dup(); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output->equalsTo(exp1)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, random_shuffle_test3) { + + auto input = NDArrayFactory::create('c', {3, 2, 1}); + input.linspace(1); + NDArray exp1 = input.dup(); + NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE); + NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE); + NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE); + NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE); + NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, true); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.equalsTo(exp1) || input.equalsTo(exp2) || input.equalsTo(exp3) + || input.equalsTo(exp4) || input.equalsTo(exp5) || input.equalsTo(exp6)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, random_shuffle_test4) { + + auto input = NDArrayFactory::create('c', {3, 2, 1}); + input.linspace(1); + NDArray exp1 = input.dup(); + NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE); + NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE); + NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE); + NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE); + NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2) || output->equalsTo(exp3) + || output->equalsTo(exp4) || output->equalsTo(exp5) || 
output->equalsTo(exp6)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, random_shuffle_test5) { + auto input = NDArrayFactory::create('c', {4}); + input.linspace(1); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); + auto output = results.at(0); + // output->printBuffer(); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(!output->equalsTo(input)); + + bool hasDublicates = false; + for(int i = 0; i < output->lengthOf() - 1; ++i) + for(int j = i+1; j < output->lengthOf(); ++j) + if(output->t(i) == output->t(j)) { + hasDublicates = true; + i = output->lengthOf(); + break; + } + ASSERT_TRUE(!hasDublicates); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, random_shuffle_test6) { + auto input = NDArrayFactory::create('c', {4,1,1}); + input.linspace(1); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + // ASSERT_TRUE(!output->equalsTo(input)); + + bool hasDublicates = false; + for(int i = 0; i < output->lengthOf() - 1; ++i) + for(int j = i+1; j < output->lengthOf(); ++j) + if(output->t(i) == output->t(j)) { + hasDublicates = true; + i = output->lengthOf(); + break; + } + ASSERT_TRUE(!hasDublicates); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, random_shuffle_test7) { + auto input = NDArrayFactory::create('c', {16010}); + input.linspace(1); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); + auto output = results.at(0); + // output->printBuffer(); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(!output->equalsTo(input)); + + auto vec1 = input.getBufferAsVector(); + auto vec2 = output->getBufferAsVector(); + std::sort(vec2.begin(), vec2.end()); + 
ASSERT_TRUE(std::equal(vec1.begin(), vec1.end(), vec2.begin())); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, random_shuffle_test8) { + auto input = NDArrayFactory::create('c', {1,4,1}); + input.linspace(1); + NDArray inCopy = input.dup(); + + sd::ops::random_shuffle op; + auto results = op.evaluate({&input}, {}, {}, {}, {}, false); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(input.equalsTo(inCopy)); + +} + +TEST_F(DeclarableOpsTests5, random_shuffle_test9) { + + auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto z = x.ulike(); + + sd::ops::random_shuffle op; + auto status = op.execute({&x}, {&z}); + ASSERT_EQ(Status::OK(), status); + + auto vec = z.getBufferAsVector(); + std::sort(vec.begin(), vec.end()); + ASSERT_EQ(std::vector({1, 2, 3, 4}), vec); +} + +//////////////////////////////////////////////////////////////////////////////////////// + +TEST_F(DeclarableOpsTests5, EmbeddingLookup_1) { + + auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, 13, 23, + 14, 24, 15, 25, 16, 26, 17, 27, + 18, 28, 19, 29, 20, 30, 21, 31}); + + auto y = NDArrayFactory::create({1, 1, 1, 0, 0, 0, 2, 2, 2}); + auto exp = NDArrayFactory::create('c', {9, 4, 2}, {14, 24, 15, 25, 16, 26, 17, 27, 14, 24, 15, 25, + 16, 26, 17, 27, 14, 24, 15, 25, 16, 26, 17, 27, + 10, 20, 11, 21, 12, 22, 13, 23, 10, 20, 11, 21, + 12, 22, 13, 23, 10, 20, 11, 21, 12, 22, 13, 23, + 18, 28, 19, 29, 20, 30, 21, 31, 18, 28, 19, 29, + 20, 30, 21, 31, 18, 28, 19, 29, 20, 30, 21, 31}); + + // y.printShapeInfo("y shape"); + // y.printIndexedBuffer("y buffer"); + + sd::ops::embedding_lookup op; + auto result = op.evaluate({&x, &y}, {}, {0}); + auto output = result.at(0); + // x.printShapeInfo("Input"); + output->printShapeInfo("Output"); + exp.printShapeInfo("Expected"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + //output->printIndexedBuffer("Output"); + 
//exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +TEST_F(DeclarableOpsTests5, EmbeddingLookup_2) { + + auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 30, 40, 50, 60, + 70, 80, 90, 10, 11, 12, + 13, 14, 15, 16, 17, 18, + 19, 20, 21, 22, 23, 24}); + //1, 0, 1, 0, 1, 0 + auto y = NDArrayFactory::create({1, 0, 1, 0, 1, 0}); + auto exp = NDArrayFactory::create('c', {6, 4, 2}, {90, 10, 11, 12, 13, 14, + 15, 16, 10, 20, 30, 40, + 50, 60, 70, 80, 90, 10, + 11, 12, 13, 14, 15, 16, + 10, 20, 30, 40, 50, 60, + 70, 80, 90, 10, 11, 12, + 13, 14, 15, 16, 10, 20, + 30, 40, 50, 60, 70, 80}); + + // y.printShapeInfo("y shape"); + // y.printIndexedBuffer("y buffer"); + + sd::ops::embedding_lookup op; + auto result = op.evaluate({&x, &y}, {}, {0}); + auto output = result.at(0); + // x.printShapeInfo("Input"); + // output->printShapeInfo("Output"); + // exp.printShapeInfo("Expected"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // output->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +TEST_F(DeclarableOpsTests5, EmbeddingLookup_3) { + + + auto y = NDArrayFactory::create('c', {3,2}, {5, 4, 4, 5, 3, 3}); + auto exp = NDArrayFactory::create('c', {6, 3, 3}, { + 6, 20, 11, 21, 12, 22, 13, 23, 14, + 5, 20, 11, 21, 12, 22, 13, 23, 14, + 5, 20, 11, 21, 12, 22, 13, 23, 14, + 6, 20, 11, 21, 12, 22, 13, 23, 14, + 4, 20, 11, 21, 12, 22, 13, 23, 14, + 4, 20, 11, 21, 12, 22, 13, 23, 14 }); + + // y.printShapeInfo("y shape"); + // y.printIndexedBuffer("y buffer"); + auto p1 = NDArrayFactory::create('c', {3,3}, {1, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p2 = NDArrayFactory::create('c', {3,3}, {2, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p3 = NDArrayFactory::create('c', {3,3}, {3, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p4 = NDArrayFactory::create('c', {3,3}, {4, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p5 = NDArrayFactory::create('c', 
{3,3}, {5, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p6 = NDArrayFactory::create('c', {3,3}, {6, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p7 = NDArrayFactory::create('c', {3,3}, {7, 20, 11, 21, 12, 22, 13, 23, 14}); + auto p8 = NDArrayFactory::create('c', {3,3}, {8, 20, 11, 21, 12, 22, 13, 23, 14}); + +// res = tf.nn.embedding_lookup((p1, p2, p3, p4, p5, p6, p7), ids, 'mod') + + sd::ops::embedding_lookup op; + auto result = op.evaluate({&p1, &p2, &p3, &p4, &p5, &p6, &p7, &p8, &y}, {}, {1}); + auto output = result.at(0); + // x.printShapeInfo("Input"); + // output->printIndexedBuffer("Output"); + // exp.printShapeInfo("Expected"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + // output->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(output)); + + +} +/* @Test + public void testDynamicPartition(){ + INDArray data = Nd4j.createFromArray(2, 1, 2, 0); + INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); + INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") + .addOutputs(Nd4j.createUninitialized(DataType.INT, 2), Nd4j.createUninitialized(DataType.INT, 1), Nd4j.createUninitialized(DataType.INT, 1)) + .addIntegerArguments(3) //3 partitions + .addInputs(data, partitions).build()); + + INDArray exp0 = Nd4j.createFromArray(2, 0); + INDArray exp1 = Nd4j.createFromArray(2); + INDArray exp2 = Nd4j.createFromArray(1); + + assertEquals(exp0, out[0]); //Usually just gives [0,0] + assertEquals(exp1, out[1]); + assertEquals(exp2, out[2]); + }*/ +TEST_F(DeclarableOpsTests5, DynamicPartition_01) { + + auto x = NDArrayFactory::create({2,1,2,0}); + + auto y = NDArrayFactory::create({0,2,1,0}); + + int numPartition = 3; + std::vector exp( { NDArrayFactory::create('c', {2}, {2, 0}), + NDArrayFactory::create('c', {1}, {2}), + NDArrayFactory::create('c', {1}, {1})}); + + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); + + 
ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 + + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + // output->printShapeInfo("Output shape> "); + // output->printIndexedBuffer("Output data> "); + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } + + +} + +TEST_F(DeclarableOpsTests5, DynamicPartition_1) { + + auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20, 11, 21, 12, 22, + 13, 23, 14, 24, 15, 25, 16, 26, 17, 27, + 18, 28, 19, 29, 20, 30, 21, 31}); + + auto y = NDArrayFactory::create('c', {3, 4, 2}, {0, 0, 0, 0, 0, 0, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1 + } + ); +/* auto y = NDArrayFactory::create('c', {3, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, 2.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f + } + ); +*/ + int numPartition = 3; + std::vector exp( { NDArrayFactory::create('c', {6}, {10, 20, 11, 21, 12, 22}), + NDArrayFactory::create('c', {8}, {18, 28, 19, 29, 20, 30, 21, 31}), + NDArrayFactory::create('c', {10}, {13, 23, 14, 24, 15, 25, 16, 26, 17, 27})}); + + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 + + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + // output->printShapeInfo("Output shape> "); + // output->printIndexedBuffer("Output data> "); + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } + + +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST_F(DeclarableOpsTests5, DynamicPartition_2) { + + auto x = NDArrayFactory::create('c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); + auto y = NDArrayFactory::create('c', {2, 4}, {1, 2, 1, 2, 1, 2, 3, 0}); + + 
std::vector exp( {NDArrayFactory::create('c', {1}, {-2.2}), + NDArrayFactory::create('c', {3}, {0.1, 5.2, -1.}), + NDArrayFactory::create('c', {3}, {-1., 4.3, 7.4}), + NDArrayFactory::create('c', {1}, {0.0})}); + + sd::ops::dynamic_partition op; + int numPartition = 4; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 + + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } + + +} + + +TEST_F(DeclarableOpsTests5, DynamicPartition_3) { + + auto x = NDArrayFactory::create('c', {2, 4}, {0.1f, -1.f, 5.2f, 4.3f, -1.f, 7.4f, 0.0f, -2.2f}); + auto y = NDArrayFactory::create('c', {2, 4}, {0, 1, 0, 2, 0, 2, 3, 0}); + + std::vector exp( {NDArrayFactory::create({0.1f, 5.2f, -1.f, -2.2f}), + NDArrayFactory::create('c', {1}, {-1.f}), + NDArrayFactory::create({4.3f, 7.4f}), + NDArrayFactory::create('c', {1}, {0.0f})}); + + sd::ops::dynamic_partition op; + int numPartition = 4; + auto result = op.evaluate({&x, &y}, {}, {numPartition}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(result.size(), numPartition); // result has the same size as given param 4 + + for (int e = 0; e < result.size(); e++) { + auto output = result.at(e); + if (output) + { + // output->printShapeInfo("Output shape> "); + // exp[e].printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + + ASSERT_TRUE(exp[e].isSameShape(output)); + ASSERT_TRUE(exp[e].equalsTo(output)); + } + else + { + ASSERT_TRUE(exp[e].lengthOf() == 0); + } + } + + +} + +TEST_F(DeclarableOpsTests5, DynamicStitch_empty_1) { + auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); + auto i1 = NDArrayFactory::empty(); + auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); + + auto d0 = NDArrayFactory::create('c', {2, 5}, 
{0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126}); + auto d1 = NDArrayFactory::empty(); + auto d2 = NDArrayFactory::create('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326}); + + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + +} + +TEST_F(DeclarableOpsTests5, DynamicStitch_empty_2) { + auto i0 = NDArrayFactory::create('c', {2}, {2, 3}); + auto i1 = NDArrayFactory::create('c', {0}); + auto i2 = NDArrayFactory::create('c', {2}, {0, 1}); + + auto d0 = NDArrayFactory::create('c', {2, 5}, {0.085571885,0.7937801,0.65908563,0.55552566,0.15962744,0.7787856,0.80119777,0.72437465,0.23089433,0.72714126}); + auto d1 = NDArrayFactory::create('c', {0, 5}); + auto d2 = NDArrayFactory::create('c', {2, 5}, {0.94414854,0.5956861,0.8668989,0.3502196,0.5100082,0.061725974,0.6621324,0.034165382,0.32576954,0.51917326}); + + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&i0, &i1, &i2, &d0, &d1, &d2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST_F(DeclarableOpsTests5, DynamicStitch_1) { + + auto x1 = NDArrayFactory::create({1, 3, 5, 0}); + auto x2 = NDArrayFactory::create({2, 4}); + auto y2 = NDArrayFactory::create({-1., -1.}); + auto y1 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); + + + auto exp = NDArrayFactory::create({7.4f, 0.1f, -1.f, 5.2f, -1.f, 4.3f}); + + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST_F(DeclarableOpsTests5, 
DynamicStitch_2) { + + auto x1 = NDArrayFactory::create({1, 3}); + auto x2 = NDArrayFactory::create({5, 0, 2, 4}); + auto y1 = NDArrayFactory::create({-1.f, -1.f}); + auto y2 = NDArrayFactory::create({0.1f, 5.2f, 4.3f, 7.4f}); + + + auto exp = NDArrayFactory::create({5.2f, -1.f, 4.3f, -1.f, 7.4f, 0.1f}); + + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&x1, &x2, &y1, &y2}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + // output->printShapeInfo("Output shape> "); + // exp.printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + // exp.printIndexedBuffer("Expected res>"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, fusedBatchNorm_test1) { + + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + auto scale = NDArrayFactory::create('c', {4}); + + scale = 0.5; + auto offset = NDArrayFactory::create('c', {4}); + offset = 2.; + auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654, 1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, 1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, 2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155, 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346, 2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538, 2.79662538, 2.79662538, 2.79662538}); + auto expBatchMean = NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); + auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); + + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {}, 
{0,1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, fusedBatchNorm_test2) { + + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + + auto scale = NDArrayFactory::create('c', {4}); + + scale = 0.5; + auto offset = NDArrayFactory::create('c', {4}); + offset = 2.; + auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {1.20347691, 1.20347691, 1.20347691, 1.20347691, 1.34829926, 1.34829926, 1.34829926, 1.34829926, 1.49312162, 1.49312162, 1.49312162, 1.49312162, 1.6379441 , 1.6379441 , 1.6379441 , 1.6379441 , 1.78276646, 1.78276646, 1.78276646, 1.78276646, 1.92758882, 1.92758882, 1.92758882, 1.92758882, 2.0724113 , 2.0724113 , 2.0724113 , 2.0724113 , 2.21723366, 2.21723366, 2.21723366, 2.21723366, 2.36205602, 2.36205602, 2.36205602, 2.36205602, 2.50687838, 2.50687838, 2.50687838, 2.50687838, 2.65170074, 2.65170074, 2.65170074, 2.65170074, 2.79652309, 2.79652309, 2.79652309, 2.79652309}); + auto expBatchMean = NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); + auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0,1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, fusedBatchNorm_test3) { + + auto x = 
NDArrayFactory::create('c', {2, 4, 2, 3}); + x.linspace(1); + + auto scale = NDArrayFactory::create('c', {4}); + + scale = 0.5; + auto offset = NDArrayFactory::create('c', {4}); + offset = 2.; + auto expY = NDArrayFactory::create('c', {2, 4, 2, 3}, {1.20337462, 1.20337462, 1.20337462, 1.20337462, 1.34821558, 1.34821558, 1.34821558, 1.34821558, 1.49305654, 1.49305654, 1.49305654, 1.49305654, 1.63789749, 1.63789749, 1.63789749, 1.63789749, 1.78273857, 1.78273857, 1.78273857, 1.78273857, 1.92757952, 1.92757952, 1.92757952, 1.92757952, 2.0724206 , 2.0724206 , 2.0724206 , 2.0724206 , 2.21726155, 2.21726155, 2.21726155, 2.21726155, 2.36210251, 2.36210251, 2.36210251, 2.36210251, 2.50694346, 2.50694346, 2.50694346, 2.50694346, 2.65178442, 2.65178442, 2.65178442, 2.65178442, 2.79662538, 2.79662538, 2.79662538, 2.79662538}); + auto expBatchMean = NDArrayFactory::create('c', {4}, {23., 24., 25., 26.}); + auto expBatchVar = NDArrayFactory::create('c', {4}, {208.00001526, 208.00001526, 208.00001526, 208.00001526}); + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {}, {1,1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, fusedBatchNorm_test4) { + + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + std::vector shape = {4}; + auto scale = NDArrayFactory::create('c', shape); + auto offset = NDArrayFactory::create('c', shape); + auto mean = NDArrayFactory::create('c', shape); + auto variance = NDArrayFactory::create('c', shape); + + scale = 0.5; + offset = 2.; + mean = 25.; + variance = 5.; + + auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {-3.36602688, -3.14244223, -2.91885757, 
-2.6952734 , -2.47168875, -2.24810457, -2.02451992, -1.80093551, -1.57735109, -1.35376668, -1.13018227, -0.90659785, -0.68301344, -0.45942879, -0.23584437, -0.01225996, 0.21132445, 0.43490887, 0.65849328, 0.88207781, 1.10566223, 1.32924664, 1.55283117, 1.77641559, 2. , 2.22358441, 2.44716883, 2.67075348, 2.89433765, 3.11792231, 3.34150672, 3.56509113, 3.78867555, 4.01225996, 4.23584461, 4.45942879, 4.68301344, 4.90659809, 5.13018227, 5.35376644, 5.57735109, 5.80093575, 6.02451992, 6.24810457, 6.47168875, 6.6952734 , 6.91885757, 7.14244223}); + auto expBatchMean = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + auto expBatchVar = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {}, {0,1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, fusedBatchNorm_test5) { + + auto x = NDArrayFactory::create('c', {2, 2, 3, 4}); + x.linspace(1); + std::vector shape = {4}; + auto scale = NDArrayFactory::create('c', shape); + auto offset = NDArrayFactory::create('c', shape); + auto mean = NDArrayFactory::create('c', shape); + auto variance = NDArrayFactory::create('c', shape); + + scale = 0.5; + offset = 2.; + mean = 25.; + variance = 5.; + + auto expY = NDArrayFactory::create('c', {2, 2, 3, 4}, {-3.33992958e+00, -3.11743259e+00, -2.89493513e+00, -2.67243814e+00, -2.44994116e+00, -2.22744417e+00, -2.00494719e+00, -1.78244996e+00, -1.55995297e+00, -1.33745599e+00, -1.11495876e+00, -8.92461777e-01, -6.69964790e-01, -4.47467566e-01, -2.24970579e-01, -2.47359276e-03, 2.20023513e-01, 4.42520618e-01, 6.65017605e-01, 8.87514710e-01, 1.11001182e+00, 
1.33250880e+00, 1.55500591e+00, 1.77750289e+00, 2.00000000e+00, 2.22249699e+00, 2.44499421e+00, 2.66749120e+00, 2.88998818e+00, 3.11248541e+00, 3.33498240e+00, 3.55747938e+00, 3.77997637e+00, 4.00247383e+00, 4.22497082e+00, 4.44746780e+00, 4.66996479e+00, 4.89246178e+00, 5.11495876e+00, 5.33745575e+00, 5.55995274e+00, 5.78244972e+00, 6.00494719e+00, 6.22744417e+00, 6.44994116e+00, 6.67243814e+00, 6.89493513e+00, 7.11743259e+00}); + auto expBatchMean = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + auto expBatchVar = NDArrayFactory::create('c', shape, {0., 0., 0., 0.}); + + + sd::ops::fused_batch_norm op; + auto results = op.evaluate({&x, &scale, &offset}, {0.05}, {0,1}); + auto y = results.at(0); + auto batchMean = results.at(1); + auto batchVar = results.at(2); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expY.isSameShape(y)); + ASSERT_TRUE(expBatchMean.isSameShape(batchMean)); + ASSERT_TRUE(expBatchVar.isSameShape(batchVar)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, confusion_matrix_test1) { + + auto labels = NDArrayFactory::create('c', {1, 3}, {1, 2, 4}); + auto predictions = NDArrayFactory::create('c', {1, 3}, {2, 2, 4}); + auto expected = NDArrayFactory::create('c', {5, 5}, {0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}); + + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions}, {}, {}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, confusion_matrix_test2) { + + auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); + auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 1, 0, 0, 0, 0, 1}); + + 
sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions}, {}, {3}); + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, confusion_matrix_test3) { + + auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); + auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); + auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); + + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, confusion_matrix_test4) { + + auto labels = NDArrayFactory::create('c', {1, 2}, {1, 2}); + auto predictions = NDArrayFactory::create('c', {1, 2}, {0, 2}); + auto weights = NDArrayFactory::create('c', {1, 2}, {100, 200}); + auto expected = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 100, 0, 0, 0, 0, 200}); + + sd::ops::confusion_matrix op; + auto results = op.evaluate({&labels, &predictions, &weights}, {}, {3}, {}, {sd::DataType::DOUBLE}); + auto output = results.at(0); + + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, ZeroFraction_1) { + + auto x = NDArrayFactory::create('c', {3, 4, 2}, {0, 20, 30, 0, 50, 0, + 70, 0, 90, 0, 11, 12, + 13, 14, 15, 16, 17, 18, + 19, 0, 21, 22, 23, 24}); + + sd::ops::zero_fraction op; + 
auto res = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), res.status()); + ASSERT_TRUE(res.at(0)->isScalar()); + ASSERT_EQ(res.at(0)->e(0), 0.25); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, ZeroFraction_2) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); + + sd::ops::zero_fraction op; + auto res = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), res.status()); + ASSERT_TRUE(res.at(0)->isScalar()); + ASSERT_EQ(res.at(0)->e(0), 0.375); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, ZeroFraction_3) { + + auto x = NDArrayFactory::create('f', {2, 2, 2}, {5.5, 0., 0.3, 5.5, 8.6, 0., 0., 0.4}); + + sd::ops::zero_fraction op; + auto res = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), res.status()); + ASSERT_TRUE(res.at(0)->isScalar()); + ASSERT_EQ(res.at(0)->e(0), 0.375); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_1) { + + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto exp = NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_2) { + + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto y = NDArrayFactory::create('c', { 2, 3 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f, 
300.f }); + + auto exp = NDArrayFactory::create('c', { 1, 3 }, { 166.f, 269.f, 326.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_3) { + + auto x = NDArrayFactory::create('c', { 1, 2 }, { 1.f, 11.f }); + auto y = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); + auto b = NDArrayFactory::create('c', { 1 }, { 200.f }); + + auto exp = NDArrayFactory::create('c', { 1,1 }, { 244.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_4) { + + auto x = NDArrayFactory::create('f', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('f', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto exp = NDArrayFactory::create('f', { 2,2 }, { 140.f, 287.f, 233.f, 351.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_5) { + + auto x = NDArrayFactory::create('c', { 2,3 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 3,2 }, { 11.f, 3.f, 4.f, 5.f, 6.f, 2.f }); + + y = y.transpose(); + + auto b = NDArrayFactory::create({ 100.f, 200.f }); + + auto exp = 
NDArrayFactory::create('c', { 2,2 }, { 173.f, 264.f, 310.f, 279.f }); + + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }, {}, { 1 }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_6) { + + auto x = NDArrayFactory::create('c', { 3, 2 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 2, 1 }, { 11.f, 3.f }); + + auto b = NDArrayFactory::create('c', { 1 }, { 100.f }); + + auto exp = NDArrayFactory::create('c', { 3, 1 }, { 144.f, 175.f, 173.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, XWPlusB_7) { + + auto x = NDArrayFactory::create('c', { 3, 4 }, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f, 1.f, 11.f, 3.f, 14.f, 5.f, 6.f }); + auto y = NDArrayFactory::create('c', { 4, 5 }, { 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 11.f, 3.f, 3.f, 11.f, 3.f, 11.f }); + + auto b = NDArrayFactory::create('c', { 5 }, { 100.f, 200.f, 300.f, 400.f, 500.f }); + + auto exp = NDArrayFactory::create('c', { 3, 5 }, { 219.f, 375.f, 531.f, 575.f, 731.f, 217.f, 317.f, 505.f, 517.f, 705.f, 248.f, 396.f, 496.f, 596.f, 696.f }); + + sd::ops::xw_plus_b op; + auto result = op.evaluate({ &x, &y, &b }); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, StopGradient_1) { + + auto x 
= NDArrayFactory::create('c', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + + sd::ops::stop_gradient op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + // output->printShapeInfo("Output shape> "); + // x.printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + // x.printIndexedBuffer("Expected res>"); + + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, StopGradient_2) { + + auto x = NDArrayFactory::create('f', {2,3}, { 1.f, 11.f, 3.f, 14.f, 5.f, 6.f}); + + sd::ops::stop_gradient op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto output = result.at(0); + + // output->printShapeInfo("Output shape> "); + // x.printShapeInfo("Expected shape> "); + // output->printIndexedBuffer("Output data> "); + // x.printIndexedBuffer("Expected res>"); + + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test1) { + + auto input = NDArrayFactory::create('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); + auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-2.16985e+00,-1.69846e-01,-3.16985e+00, -1.31507e+00,-6.31507e+00,-3.15072e-01, -8.00046e+00,-4.58767e-04,-9.00046e+00, -1.31327e+00,-1.23133e+01,-3.13266e-01, -1.40000e+01,-1.13743e-06,-1.50000e+01, -1.31326e+00,-1.83133e+01,-3.13262e-01, -2.00000e+01,-2.81941e-09,-2.10000e+01, -1.31326e+00,-2.43133e+01,-3.13262e-01, -2.73133e+01,-1.31326e+00,-3.13262e-01}); + + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + 
ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test2) { + + auto input = NDArrayFactory::create('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); + auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-3.05095e+00,-3.04946e+00,-5.00705e+00, -5.09458e-02,-7.04946e+00,-7.04851e-03, -6.05095e+00,-4.94556e-02,-8.00705e+00, -3.04859e+00,-1.30000e+01,-3.04859e+00, -1.50486e+01,-2.37286e-06,-1.70486e+01, -4.85876e-02,-1.60000e+01,-4.85874e-02, -2.10000e+01,-3.04859e+00,-2.51269e+01, -7.96007e-10,-2.50486e+01,-2.12693e+00, -2.40000e+01,-4.85874e-02,-1.26928e-01}); + + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {1}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test3) { + + auto input = NDArrayFactory::create('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); + auto expOutput = NDArrayFactory::create('c', {3, 3, 3}, {-2.16985e+00,-1.69846e-01,-3.16985e+00, -1.31507e+00,-6.31507e+00,-3.15072e-01, -8.00046e+00,-4.58767e-04,-9.00046e+00, -1.31327e+00,-1.23133e+01,-3.13266e-01, -1.40000e+01,-1.13743e-06,-1.50000e+01, -1.31326e+00,-1.83133e+01,-3.13262e-01, -2.00000e+01,-2.81941e-09,-2.10000e+01, -1.31326e+00,-2.43133e+01,-3.13262e-01, -2.73133e+01,-1.31326e+00,-3.13262e-01}); + + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {2}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + + +} + +////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests5, log_softmax_test5) { + + auto input = NDArrayFactory::create('c', {3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5}); + auto expOutput = NDArrayFactory::create('c', {3, 3}, {-2.16985, -0.16985, -3.16985, -1.31507, -6.31507, -0.31507, -9.31335, -1.31335, -0.31335}); + + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test6) { + + auto input = NDArrayFactory::create('c', {3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5}); + auto expOutput = NDArrayFactory::create('c', {3, 3}, {-3.05095,-3.04946,-7.12773, -0.05095,-7.04946,-2.12773, -6.05095,-0.04946,-0.12773}); + + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test7) { + + auto input = NDArrayFactory::create('c', {1, 5}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create('c', {1, 5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); + + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test8) { + + auto input = NDArrayFactory::create('c', {1, 5}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 0, 0}); + + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = 
results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test9) { + + auto input = NDArrayFactory::create('c', {5, 1}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create('c', {5, 1}, {0, 0, 0, 0, 0}); + + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test10) { + + auto input = NDArrayFactory::create('c', {5, 1}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create('c', {5, 1}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); + + sd::ops::log_softmax op; + auto results = op.evaluate({&input}, {}, {0}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test11) { + + auto input = NDArrayFactory::create('c', {5}, {-1, 1, -2, 2, 3}); + auto expOutput = NDArrayFactory::create('c', {5}, {-4.42414, -2.42414, -5.42414, -1.42414, -0.42414}); + + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_test12) { + + auto input = NDArrayFactory::create('c', {1, 4}, {0.1869, -1.4918, -0.6497, -0.8864}); + auto expOutput = NDArrayFactory::create('c', {1, 4}, {-0.6738, 
-2.3525, -1.5104, -1.7472}); + + for (int i = 0; i < 10; ++i) + { + sd::ops::log_softmax op; + auto results = op.evaluate({&input}); + auto z = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOutput.isSameShape(z)); + ASSERT_TRUE(expOutput.equalsTo(z, 1e-4)); + + + } +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_bp_test1) { + + auto input = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); + auto epsilon = NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); + auto exp = NDArrayFactory::create('c', {2, 2}, {-0.07311,0.02689, -0.07311,0.02689}); + + sd::ops::log_softmax_bp op; + auto results = op.evaluate({&input, &epsilon}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, log_softmax_bp_test2) { + + auto input = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); + auto epsilon = NDArrayFactory::create('c', {2, 2}, {0.1, 0.2, 0.3, 0.4}); + auto exp = NDArrayFactory::create('c', {2, 2}, {-0.17616, -0.17616, 0.02384, 0.02384}); + + sd::ops::log_softmax_bp op; + auto results = op.evaluate({&input, &epsilon}, {}, {0}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, ELU_1) { + + auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, { -0.63212055, 2. 
, 1.5, -0.753403, 1., 2., 2., 1.}); + auto res = NDArrayFactory::create('c', {2, 2, 2}); + + input.applyScalar(sd::scalar::ELU, 1.f, res); + + ASSERT_TRUE(res.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, L2_Loss_1) { + + auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); + double exp(9.605); + + sd::ops::l2_loss op; + auto results = op.evaluate({&input}, {}, {}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(output->isScalar()); + + ASSERT_EQ(output->e(0), exp); + + +} + +TEST_F(DeclarableOpsTests5, L2_Loss_2) { + auto x = NDArrayFactory::create(0.7787855863571167); + auto e = NDArrayFactory::create(0.303254); + + sd::ops::l2_loss op; + auto results = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), results.status()); + + auto z = results.at(0); + + ASSERT_EQ(e, *z); + + +} + +TEST_F(DeclarableOpsTests5, L2_Loss_3) { + auto x = NDArrayFactory::create(0.7787855863571167); + auto e = NDArrayFactory::create(0.303254); + auto z = NDArrayFactory::create(0.0); + + sd::ops::l2_loss op; + auto status = op.execute({&x}, {&z} , {}, {}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, LogPoissonLoss_1) { + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + + auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. 
, 1.5, -1.4, 1., 2., 2., 1.}); + auto targets = NDArrayFactory::create('c', {2, 2, 2}, {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0}); + + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1.3678794, 5.389056, 2.981689, 1.6465969, 1.7182817, 5.389056, 5.389056, 1.7182817}); + + sd::ops::log_poisson_loss op; + auto results = op.evaluate({&input, &weights, &targets}, {}, {0}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, LogPoissonLoss_2) { + + auto weights = NDArrayFactory::create('c', {1, 1}, {1}); + + auto input = NDArrayFactory::create('c', {2, 2, 2}, { -1., 2. , 1.5, -1.4, 1., 2., 2., 1.}); + auto targets = NDArrayFactory::create('c', {2, 2, 2}, {2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0}); + + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3.0196857, 4.0408626, 2.1334953, 3.6984034, 1.3700882, 4.0408626, 4.0408626, 1.3700882}); + + sd::ops::log_poisson_loss op; + auto results = op.evaluate({&input, &weights, &targets}, {}, {0, 1}); + auto output = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, NormalizeMoments_1) { + + auto means = NDArrayFactory::create('c', {2, 3, 4}, { 11., 3., 14., 5., + 6., 9., 3.5, 7., + 21., 3., 14., 15., + 6., 9., 3.5, 7., + 11., 13., 14., 5., + 16., 9., 13.5, 7.}); + + auto deviance = NDArrayFactory::create('c', {2, 3, 4}, { 21., 13., 24., 15., + 16., 19., 13.5, 17., + 31., 13., 24., 25., + 16., 19., 13.5, 17., + 21., 23., 24., 15., + 26., 19., 23.5, 17.}); + + auto counts = NDArrayFactory::create(2.0); + + auto expMeans = NDArrayFactory::create('c', {2, 3, 4}, { + 5.5, 1.5, 7., 2.5, + 3., 4.5, 1.75, 3.5, 
+ 10.5, 1.5, 7., 7.5, + 3., 4.5, 1.75, 3.5, + 5.5, 6.5, 7., 2.5, + 8., 4.5, 6.75, 3.5}); + + auto expDeviance = NDArrayFactory::create('c', {2, 3, 4}, { + -19.75, 4.25, -37., 1.25, + -1., -10.75, 3.6875, -3.75, + -94.75, 4.25, -37., -43.75, + -1., -10.75, 3.6875, -3.75, + -19.75, -30.75, -37., 1.25, + -51., -10.75, -33.8125, -3.75}); + + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); + + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); + + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, NormalizeMoments_2) { + + auto means = NDArrayFactory::create('c', {3, 2, 4}, { 11., 3., 14., 5., + 6., 9., 3.5, 7., + 21., 3., 14., 15., + 6., 9., 3.5, 7., + 11., 13., 14., 5., + 16., 9., 13.5, 7.}); + + auto deviance = NDArrayFactory::create('c', {3, 2, 4}, { 21., 13., 24., 15., + 16., 19., 13.5, 17., + 31., 13., 24., 25., + 16., 19., 13.5, 17., + 21., 23., 24., 15., + 26., 19., 23.5, 17.}); + + auto counts = NDArrayFactory::create(12.0); + + auto expMeans = NDArrayFactory::create('c', {3, 2, 4}, { 0.9166667, 0.25, 1.1666667, 0.4166667, + 0.5, 0.75, 0.2916667, 0.5833334, + 1.75, 0.25, 1.1666667, 1.25, + 0.5, 0.75, 0.2916667, 0.5833334, + 0.9166667, 1.0833334, 1.1666667, 0.4166667, + 1.3333334, 0.75, 1.125, 0.5833334}); + + auto expDeviance = NDArrayFactory::create('c', {3, 2, 4}, { + 0.9097222, 1.0208334, 0.6388887, 1.0763888, + 1.0833334, 1.0208334, 1.0399306, 1.076389, + -0.4791665, 1.0208334, 0.6388887, 0.5208335, + 1.0833334, 1.0208334, 1.0399306, 1.076389, + 0.9097222, 0.7430556, 0.6388887, 1.0763888, + 0.38888884, 1.0208334, 0.6927084, 1.076389}); + + 
sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &deviance}, {0.0}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); + + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); + + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests5, NormalizeMoments_3) { + + auto means = NDArrayFactory::create('c', {3, 2, 4}, { 11., 3., 14., 5., + 6., 9., 3.5, 7., + 21., 3., 14., 15., + 6., 9., 3.5, 7., + 11., 13., 14., 5., + 16., 9., 13.5, 7.}); + + auto deviance = NDArrayFactory::create('c', {3, 2, 4}, { 21., 13., 24., 15., + 16., 19., 13.5, 17., + 31., 13., 24., 25., + 16., 19., 13.5, 17., + 21., 23., 24., 15., + 26., 19., 23.5, 17.}); + + auto counts = NDArrayFactory::create(12.0); + double shift = 10.0; + auto expMeans = NDArrayFactory::create('c', {3, 2, 4}, { 10.9166667, 10.25, 11.1666667, 10.4166667, + 10.5, 10.75, 10.2916667, 10.5833334, + 11.75, 10.25, 11.1666667, 11.25, + 10.5, 10.75, 10.2916667, 10.5833334, + 10.9166667, 11.0833334, 11.1666667, 10.4166667, + 11.3333334, 10.75, 11.125, 10.5833334}); + + auto expDeviance = NDArrayFactory::create('c', {3, 2, 4}, { + 0.9097222, 1.0208334, 0.6388887, 1.0763888, + 1.0833334, 1.0208334, 1.0399306, 1.076389, + -0.4791665, 1.0208334, 0.6388887, 0.5208335, + 1.0833334, 1.0208334, 1.0399306, 1.076389, + 0.9097222, 0.7430556, 0.6388887, 1.0763888, + 0.38888884, 1.0208334, 0.6927084, 1.076389}); + + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &deviance}, {shift}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); + + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); + + 
ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); + ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + + +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests6.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests6.cpp new file mode 100644 index 000000000..075e2372b --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests6.cpp @@ -0,0 +1,2807 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 09.02.18. 
+// + + +#include "testlayers.h" +#include +#include +#include +#include + + +using namespace sd; +using namespace sd::graph; + +class DeclarableOpsTests6 : public testing::Test { +public: + + DeclarableOpsTests6() { + printf("\n"); + fflush(stdout); + } +}; + + +TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) { + auto matrix = NDArrayFactory::create('c', {5, 2}); + auto b = NDArrayFactory::create('c', {1}, {0.}); + auto e = NDArrayFactory::create('c', {1}, {1}); + auto s = NDArrayFactory::create('c', {1}, {1}); + + auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + + matrix.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) { + auto matrix = NDArrayFactory::create('c', {5, 2}); + auto b = NDArrayFactory::create('c', {1}, {0.0f}); + auto e = NDArrayFactory::create('c', {1}, {1.0f}); + auto s = NDArrayFactory::create('c', {1}, {1.0f}); + + auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + + matrix.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(exp, *z); + + +} + +TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) { + auto matrix = NDArrayFactory::create(10); + auto b = NDArrayFactory::create(0); + auto e = NDArrayFactory::create(0); + auto s = NDArrayFactory::create(1.0); + + //auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + + //matrix.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + //z->printShapeInfo("SS OS shape"); + ASSERT_TRUE(z->isEmpty()); + //ASSERT_EQ(exp, *z); + + +} + 
+TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) { + auto matrix = NDArrayFactory::create('c', {1}, {10}); + auto b = NDArrayFactory::create('c', {1}, {0.}); + auto e = NDArrayFactory::create('c', {1}, {0.}); + auto s = NDArrayFactory::create('c', {1}, {1.0}); + + auto exp = NDArrayFactory::create(10); + + //matrix.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(z->equalsTo(exp)); + //ASSERT_EQ(exp, *z); + + +} + +TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) { + int z = 0; + auto matrix = NDArrayFactory::create('c', {1}, {10}); + auto b = NDArrayFactory::create_('c', {1}, {1}); + auto e = NDArrayFactory::create_('c', {1}, {z}); + auto s = NDArrayFactory::create_('c', {1}, {1}); + sd::ops::ones_as opOnes; + //auto exp = NDArrayFactory::create('c', {2}, {1.0f, 2.0f}); + auto onesRes = opOnes.evaluate({&matrix}); + //matrix.linspace(1); + ASSERT_EQ(onesRes.status(), Status::OK()); + + auto ones = onesRes.at(0); + *ones *= 10; + auto onesD = new NDArray(ones->dup()); + + auto variableSpace = new VariableSpace(); + variableSpace->putVariable(-1, onesD); + variableSpace->putVariable(-2, b); + variableSpace->putVariable(-3, e); + variableSpace->putVariable(-4, s); + auto block = new Context(1, variableSpace, false); // not-in-place + block->fillInputs({-1}); + block->fillInputs({-2}); + block->fillInputs({-3}); + block->fillInputs({-4}); + block->getIArguments()->push_back(0); + block->getIArguments()->push_back(0); + block->getIArguments()->push_back(1); + block->getIArguments()->push_back(0); + block->getIArguments()->push_back(0); + auto inputShapes = new ShapeList({ones->shapeInfo(), b->shapeInfo(), e->shapeInfo(), s->shapeInfo()}); + sd::ops::strided_slice op; + auto result = op.calculateOutputShape(inputShapes, *block); //execute({ones, &b, &e, &s}, {}, {0, 1, 0, 0, 0}); + 
ASSERT_EQ(result->size(), 1); + ASSERT_TRUE(shape::isEmpty(result->at(0))); + //ASSERT_EQ(exp, *z); + delete block; + delete result; + delete variableSpace; + delete inputShapes; +} + +TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) { + auto matrix = NDArrayFactory::create('c', {3, 2, 2}); + auto b = NDArrayFactory::create('c', {1}, {2}); + auto e = NDArrayFactory::create('c', {1}, {3}); + auto s = NDArrayFactory::create('c', {1}, {1}); + + auto exp = NDArrayFactory::create('c', {2,2}, {0.0f, 0.0f, 0., 0.}); + + //matrix.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) { + auto matrix = NDArrayFactory::create('c', {3, 2, 2}); + auto b = NDArrayFactory::create('c', {1}, {2}); + auto e = NDArrayFactory::create('c', {1}, {3}); + auto s = NDArrayFactory::create('c', {1}, {1}); + + auto exp = NDArrayFactory::create('c', {1,2,2}, {0.0f, 0.0f, 0., 0.}); + + //matrix.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {0, 0, 0, 0, 2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) { + int zero = 0; + auto matrix = NDArrayFactory::create('c', {5, 4}); + auto b = NDArrayFactory::create('c', {1}, {zero}); + auto e = NDArrayFactory::create('c', {1}, {zero}); + auto s = NDArrayFactory::create('c', {1}, {1}); + + //auto exp = NDArrayFactory::create('c', {1,2,2}, {0.0f, 0.0f, 0., 0.}); + + //matrix.linspace(1); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &b, &e, &s}, {}, {1, 0, 0, 0, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + //ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, 
Test_StridedSlice_BP_1) { + int zero = 0; + auto matrix = NDArrayFactory::create('c', {5, 4}); +// auto b = NDArrayFactory::create('c', {1}, {zero}); +// auto e = NDArrayFactory::create('c', {1}, {zero}); +// auto s = NDArrayFactory::create('c', {1}, {1}); + + auto grad = NDArrayFactory::create('c', {5}); + + matrix.linspace(1); + grad.linspace(1); + + sd::ops::strided_slice_bp op; + auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + //ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) { + int zero = 0; + auto matrix = NDArrayFactory::create('c', {1, 2}); +// auto b = NDArrayFactory::create('c', {1}, {zero}); +// auto e = NDArrayFactory::create('c', {1}, {zero}); +// auto s = NDArrayFactory::create('c', {1}, {1}); + + auto grad = NDArrayFactory::create('c', {1}, {1.}); + + matrix.linspace(1); + //grad.linspace(1); + + sd::ops::strided_slice_bp op; + auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 2, 0, 0, 0, 1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + //ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) { + int zero = 0; + auto matrix = NDArrayFactory::create('c', {4, 8192}); +// auto b = NDArrayFactory::create('c', {1}, {zero}); +// auto e = NDArrayFactory::create('c', {1}, {zero}); +// auto s = NDArrayFactory::create('c', {1}, {1}); + + auto grad = NDArrayFactory::create('c', {4, 256}); + + matrix.linspace(1); + grad.linspace(1); + + sd::ops::strided_slice_bp op; + auto result = op.evaluate({&matrix, &grad}, {}, {1, 0, 1, 0, 0, 0, 0, 0, 256, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + //ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) { + auto x = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto exp = NDArrayFactory::create('c', {1, 1}, {4.0f}); + 
+ sd::ops::test_scalar op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, Test_Order_1) { + auto x = NDArrayFactory::create('f', {2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}); + x.linspace(1); + exp.linspace(1); + + sd::ops::order op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_NE(x.ordering(), z->ordering()); + + +} + +TEST_F(DeclarableOpsTests6, cumSum_1) { + auto x = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {1, 4}, {1.f, 3.f, 6.f, 10.f}); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, cumSum_2) { + auto x= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + auto exp= NDArrayFactory::create('c', {2, 4}, {1.f, 3.f, 6.f, 10.f, 1.f, 3.f, 6.f, 10.f}); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + // z->printIndexedBuffer("CumSum1"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, cumSum_3) { + auto x= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f}); + auto exp= NDArrayFactory::create('c', {2, 4}, {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f}); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, cumSum_4) 
{ + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 3}, {12., 15., 18., 11., 13., 15., 7., 8., 9.}); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 0}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, cumSum_5) { + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 3}, {6.f, 5.f, 3.f, 15.f, 11.f, 6.f, 24.f, 17.f, 9.f,}); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, cumSum_6) { + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 3}, {11.f, 13.f, 15.f, 7.f, 8.f, 9.f, 0.f, 0.f, 0.f}); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, cumSum_7) { + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, cumSum_8) { + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + auto exp = NDArrayFactory::create('c', {3, 3}, {5.f, 3.f, 0.f, 11.f, 6.f, 0.f, 17.f, 9.f, 0.f}); + + sd::ops::cumsum op; + auto result = op.evaluate({&x, &axis}, {}, {1, 1}, {}); + 
ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_9) { + + auto inputC = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto axis = NDArrayFactory::create(1); + + auto expFF = NDArrayFactory::create('c', {3, 5}, {1., 3., 6., 10., 15., 6., 13., 21., 30., 40., 11., 23., 36., 50., 65.}); + auto expTF = NDArrayFactory::create('c', {3, 5}, {0., 1., 3., 6., 10., 0., 6., 13., 21., 30., 0., 11., 23., 36., 50.}); + + auto expFT = NDArrayFactory::create('c', {3, 5}, {15, 14, 12, 9, 5,40, 34, 27, 19, 10,65, 54, 42, 29, 15}); //+++ + auto expTT = NDArrayFactory::create('c', {3, 5}, {14, 12, 9, 5, 0,34, 27, 19, 10, 0,54, 42, 29, 15, 0}); + + int exclusive, reverse; + + //************************************// + exclusive = 0; reverse = 0; + + sd::ops::cumsum op; + auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(expFF.equalsTo(z)); + + + //************************************// + exclusive = 1; reverse = 0; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTF.equalsTo(z)); + + + //************************************// + exclusive = 0; reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expFT.equalsTo(z)); + + + //************************************// + exclusive = 1; reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTT.equalsTo(z)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_10) { + auto x = NDArrayFactory::create('c', {4, 16, 16, 1}); + auto y = NDArrayFactory::create(-3); + + sd::ops::cumsum op; + auto result = op.evaluate({&x, &y}, {}, {1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_11) { + + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create('c', {3,3,3}, {12., 15., 18.,11., 13., 15.,7., 8., 9., 39., 42., 45.,29., 31., 33.,16., 17., 18., 66., 69., 72.,47., 49., 51.,25., 26., 27.}); + + x.linspace(1); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_12) { + + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create('c', {3,3,3}, {1., 2., 3.,5., 7., 9.,12., 15., 18., 10., 11., 12.,23., 25., 27.,39., 42., 45., 19., 20., 21.,41., 43., 45., 66., 69., 72.}); + + x.linspace(1); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_13) { + + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create('c', {3,3,3}, {11., 13., 15.,7., 8., 9.,0., 0., 0., 29., 31., 33.,16., 17., 18.,0., 0., 0., 47., 49., 51.,25., 26., 27.,0., 0., 0.}); + + x.linspace(1); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + 
ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_14) { + + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create('c', {3,3,3}, {29., 31., 33.,35., 37., 39.,41., 43., 45., 19., 20., 21.,22., 23., 24.,25., 26., 27., 0., 0., 0.,0., 0., 0.,0., 0., 0.}); + + x.linspace(1); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_15) { + + NDArray x('c', {3, 3, 3}, sd::DataType::DOUBLE); + auto exp = NDArrayFactory::create('c', {3,3,3}, {6., 5., 3.,15., 11., 6.,24., 17., 9., 33., 23., 12.,42., 29., 15.,51., 35., 18., 60., 41., 21.,69., 47., 24.,78., 53., 27.}); + + x.linspace(1); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_16) { + + NDArray x('f', {3, 4}, sd::DataType::FLOAT32); + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printShapeInfo(); + // x.printShapeInfo(); + + ASSERT_TRUE(z->ews() == 1); + ASSERT_TRUE(x.ews() == 1); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_17) { + + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); + + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); + + exp0.p(0, 1.); 
+ exp1.p(0, 1.); + + for (int i = 1; i < 1500; ++i) { + const auto prev = exp0.e(i-1); + exp0.p(i, prev + i + 1); + exp1.p(i, prev + i + 1); + } + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_18) { + + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); + + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); + + exp0.p(0, 0.); + exp1.p(0, 0.); + + for (int i = 1; i < 1500; ++i) { + const auto prev = exp0.e(i-1); + exp0.p(i, prev + i); + exp1.p(i, prev + i); + } + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_19) { + + NDArray x('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); + + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); + + exp0.p(1499, 1500.f); + exp1.p(1499, 1500.f); + + for (int i = 1498; i >= 0; --i) { + const auto prev = exp0.e(i + 1); + exp0.p(i, prev + i + 1); + exp1.p(i, prev + i + 1); + } + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {0, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // exp0.printBuffer(); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, cumSum_20) { + + NDArray x('c', {2, 1500}, 
sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1); + x1.linspace(1); + + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); + + exp0.p(1499, 0.); + exp1.p(1499, 0.); + + for (int i = 1498; i >= 0; --i) { + const auto prev = exp0.e(i + 1); + exp0.p(i, prev + i + 2); + exp1.p(i, prev + i + 2); + } + + sd::ops::cumsum op; + auto result = op.evaluate({&x}, {}, {1, 1, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto z = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2}); + sd::ops::mergemaxindex op; + + auto res = op.evaluate({&x, &y, &z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(res.at(0)->equalsTo(exp)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 60.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto z = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 6.f, 7.f, 80.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 1, 2, 1, 0, 1, 2}); + sd::ops::mergemaxindex op; + + auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64}); + + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + ASSERT_TRUE(ress.at(0)->equalsTo(exp)); + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_3) { + + auto x1 = NDArrayFactory::create('c', {3}, {1.f, 0.f, 0.f}); + auto x2 = NDArrayFactory::create('c', {3}, {0.f, 1.f, 0.f}); + auto x3 = NDArrayFactory::create('c', {3}, {0.f, 0.f, 1.f}); + NDArray z('c', {3}, sd::DataType::INT32); + NDArray expZ('c', {3}, {0, 1, 2}, sd::DataType::INT32); + + sd::ops::mergemaxindex op; + auto result = op.execute({&x1, &x2, &x3}, {&z}, {}, {}, {}); + + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(z.equalsTo(expZ)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestDropout_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto shape = NDArrayFactory::create({2, 2}); + sd::ops::dropout op; + + auto res = op.evaluate({&x, &shape}, {0.2f}, {113}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + //res.at(0)->printIndexedBuffer("Result is "); + //x.printIndexedBuffer("Input is"); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestMod_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 0, 3, 0, 5, 0, 7, 0}); + sd::ops::mod op; + + auto res = op.evaluate({&x, &y}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); +// res.at(0)->printIndexedBuffer("MOD Result is "); +// x.printIndexedBuffer("Input is"); + ASSERT_TRUE(res.at(0)->equalsTo(exp)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestMod_BP_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 
2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto eps = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}); + sd::ops::mod_bp op; + + auto res = op.evaluate({&x, &y, &eps}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); +// res.at(0)->printIndexedBuffer("MOD_BP Result is "); + + // x.printIndexedBuffer("Input is"); + ASSERT_TRUE(res.at(0)->equalsTo(exp)); + +} + +/////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, TestRank_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto y = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto eps = NDArrayFactory::create('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f}); + auto exp = NDArrayFactory::create(3); + sd::ops::rank op; + + auto res = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + + ASSERT_TRUE(res.at(0)->equalsTo(exp)); + +} +TEST_F(DeclarableOpsTests6, TestDropout_2) { +// auto x0 = NDArrayFactory::create('c', {10, 10}); +// auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {3, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f}); + + sd::ops::dropout op; + + auto res = op.evaluate({&x}, {0.4f}, {113}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + + +} + +TEST_F(DeclarableOpsTests6, TestDropout_3) { +// auto x0 = NDArrayFactory::create('c', {10, 10}); +// auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + auto shape = NDArrayFactory::create({1, 2}); + + sd::ops::dropout op; + + auto res = op.evaluate({&x, &shape}, {0.4f}, {113}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) { + + auto x = 
NDArrayFactory::create('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5, + 1.5, 1., 1.3, 1.5, 3.5, 0., 1.3, 2.5, 2.6, 2., 3., 1.4, 4.5, 1., 0.3, 0.5}); + auto expI = NDArrayFactory::create('c', {2, 2, 2, 4}, {0, 1, 2, 3,4, 5, 6, 7,8, 9, 10, 11,12, 13, 14, 15, + 0, 1, 2, 3,4, 5, 6, 7,8, 9, 10, 11,12, 13, 14, 15}); + + sd::ops::max_pool_with_argmax op; + + auto res = op.evaluate({&x}, {}, {1,1,1,1,1,1,1,1,1}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(expI.isSameShape(res.at(0))); + ASSERT_TRUE(expI.isSameShape(res.at(1))); + ASSERT_TRUE(x.equalsTo(res.at(0))); + ASSERT_TRUE(expI.equalsTo(res.at(1))); + //x.printIndexedBuffer("Input is"); + + ASSERT_TRUE(expI.equalsTo(res.at(1))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, SufficientStatistics_1) { +// auto x0 = NDArrayFactory::create('c', {10, 10}); +// auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, {5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5,1.5, 1., + 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5}); +// ------------------------------------ + double count = 8.0; + auto sumExp = NDArrayFactory::create({30.2, 5., 7.8, 22.8}); + auto sqrExp = NDArrayFactory::create({154.22, 7., 14.34, 103.62}); + + auto axis = NDArrayFactory::create({0, 1, 2}); + + sd::ops::sufficient_statistics op; + + auto res = op.evaluate({&x, &axis}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_EQ(res.at(0)->e(0), count); + ASSERT_TRUE(sumExp.equalsTo(res.at(1))); + ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, SufficientStatistics_2) { +// auto x0 = NDArrayFactory::create('c', {10, 10}); +// auto x1 = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {2, 2, 2, 4}, 
{5.5, 0., 0.3, 5.5,1.5, 0., 1.3, 6.5,8.6, 0., 0., 0.4,2.5, 1., 0.3, 4.5, + 1.5, 1., 1.3, 1.5,3.5, 0., 1.3, 2.5,2.6, 2., 3., 1.4,4.5, 1., 0.3, 0.5}); +// ------------------------------------ + double count = 4.0; + auto sumExp = NDArrayFactory::create('c', {2, 4}, { + 18.2, 3., 4.6, 8.8, + 12., 2., 3.2, 14.} + ); + + auto sqrExp = NDArrayFactory::create('c', {2, 4}, { + 113.22, 5., 10.78, 34.62, + 41., 2., 3.56, 69.} + ); + + auto axis = NDArrayFactory::create({0, 1}); + + sd::ops::sufficient_statistics op; + + auto res = op.evaluate({&x, &axis}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_EQ(res.at(0)->e(0), count); + ASSERT_TRUE(sumExp.equalsTo(res.at(1))); + ASSERT_TRUE(sqrExp.equalsTo(res.at(2))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BinCount_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, { + 1, 2, 0, 1, 2, 2, 1, 2} + ); +// ------------------------------------ + + NDArray exp('c', {3}, {1, 3, 4}, sd::DataType::INT32); + + sd::ops::bincount op; + + auto res = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BinCount_2) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, { + 1, 2, 0, 1, 2, 2, 1, 2} + ); + + auto weights = NDArrayFactory::create('c', {2, 2, 2}, { + 2, 1, 3, 1, 5, 1, 1, 6} + ); + +// ------------------------------------ + + auto exp = NDArrayFactory::create({3., 4., 13.}); + + sd::ops::bincount op; + + auto res = op.evaluate({&x, &weights}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BinCount_3) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, { + 1, 2, 0, 1, 2, 2, 1, 2} + ); + + auto weights = 
NDArrayFactory::create('c', {2, 2, 2}, { + 2, 1, 3, 1, 5, 1, 1, 6} + ); + +// ------------------------------------ + + auto exp = NDArrayFactory::create({3., 4.}); + + sd::ops::bincount op; + + auto res = op.evaluate({&x, &weights}, {}, {0, 2}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BinCount_4) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, { + 1, 2, 0, 1, 2, 2, 1, 2} + ); + + auto weights = NDArrayFactory::create('c', {2, 2, 2}, { + 2, 1, 3, 1, 5, 1, 1, 6} + ); + +// ------------------------------------ + + auto exp = NDArrayFactory::create({3., 4., 13., 0.0}); + + sd::ops::bincount op; + + auto res = op.evaluate({&x, &weights}, {}, {4, 4}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BinCount_5) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, { + 1, 2, 0, 1, 2, 2, 1, 2} + ); + + auto weights = NDArrayFactory::create('c', {2, 2, 2}, { + 2, 1, 3, 1, 5, 1, 1, 6} + ); + auto minV = NDArrayFactory::create(4); + auto maxV = NDArrayFactory::create(4); +// ------------------------------------ + + auto exp = NDArrayFactory::create({3., 4., 13., 0.0}); + + sd::ops::bincount op; + + auto res = op.evaluate({&x, &weights, &minV, &maxV}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + // res->at(0)->printBuffer("BC out"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_1) { + + auto x = NDArrayFactory::create( {2, 2, 2} ); + + auto y = NDArrayFactory::create({ 2, 1, 2}); + + auto exp = NDArrayFactory::create({2, 2, 2}); + + sd::ops::broadcast_dynamic_shape op; + + auto res = op.evaluate({&x, &y}); + + 
ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_2) { + + auto x = NDArrayFactory::create( {2, 2} ); + + auto y = NDArrayFactory::create({2, 1, 2}); + + auto exp = NDArrayFactory::create({2, 2, 2}); + + sd::ops::broadcast_dynamic_shape op; + + auto res = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_3) { + + auto x = NDArrayFactory::create( {2, 2, 2} ); + + auto y = NDArrayFactory::create({2, 1}); + + auto exp = NDArrayFactory::create({2, 2, 2}); + + sd::ops::broadcast_dynamic_shape op; + + auto res = op.evaluate({&x, &y}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_4) { + + auto x = NDArrayFactory::create( {2, 1} ); + + auto y = NDArrayFactory::create('c', {1}, {4}); + + auto exp = NDArrayFactory::create({2, 4}); + + sd::ops::broadcast_dynamic_shape op; + + auto res = op.evaluate({&x, &y}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + //res->at(0)->printBuffer("Shape SGO 4"); + ASSERT_TRUE(exp.equalsTo(res.at(0))); +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_6) { + + auto x = NDArrayFactory::create({2, 1, 4}); + + auto y = NDArrayFactory::create({2, 2, 4}); + + auto exp = NDArrayFactory::create({2, 2, 4}); + + sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); + +} + 
+///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_7) { + + auto x = NDArrayFactory::create({1, 1, 3}); + + auto y = NDArrayFactory::create({2, 4, 1}); + + auto exp = NDArrayFactory::create({2, 4, 3}); + + sd::ops::broadcast_dynamic_shape op; + auto res = op.evaluate({&x, &y}); + + ASSERT_EQ(ND4J_STATUS_OK, res.status()); + ASSERT_TRUE(exp.equalsTo(res.at(0))); +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_8) { + + auto x = NDArrayFactory::create('c', {1}, {1}); + + auto y = NDArrayFactory::create('c', {1}, {4}); + + auto z = NDArrayFactory::create('c', {1}); + + auto exp = NDArrayFactory::create('c', {1}, {4}); + + sd::ops::broadcast_dynamic_shape op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(exp.equalsTo(z)); +} + +///////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, BroadcastDynamicShape_SGO_9) { + + auto x = NDArrayFactory::create('c', {2}, {2,2}); + + auto y = NDArrayFactory::create('c', {1}, {1}); + + auto z = NDArrayFactory::create('c', {2}); + + auto exp = NDArrayFactory::create('c', {2}, {2,2}); + + sd::ops::broadcast_dynamic_shape op; + auto status = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + // ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, + -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, + -3.0, 0.0, 0.0, 4.0, 0.0, 0.0} + ); + + auto exp = NDArrayFactory::create('c', {2, 3, 3}, { + -0.2771281, 0., 0., + 0.36950415, 0., 0., + -0.2771281, 0., 0., + 0.36950415, 0., 0., + -0.2771281, 0., 0., + 0.36950415, 0., 0.} + ); +// 
8.660254 +// auto expNorm(8.660254); + + sd::ops::clip_by_global_norm op; + auto result = op.evaluate({&x}, {0.8}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + auto norm = result.at(1); + //z->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expected"); + //norm->printIndexedBuffer("Norm"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +// ASSERT_TRUE(expNorm.equalsTo(norm)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, + -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, + -3.0, 0.0, 0.0, 4.0, 0.0, 0.0} + ); + + auto a = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, + -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, + -3.0, 0.0, 0.0, 4.0, 0.0, 0.0} + ); + + auto exp = NDArrayFactory::create('c', {2, 3, 3}, { + -0.44090813, 0., 0., + 0.5878775, 0., 0., + -0.44090813, 0., 0., + 0.5878775, 0., 0., + -0.44090813, 0., 0., + 0.5878775, 0., 0.} +//12.247449 + + ); + + sd::ops::clip_by_global_norm op; + auto result = op.evaluate({&x, &a}, {1.8}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + auto y = result.at(1); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(y)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); + auto a = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0, -3.0, 0.0, 0.0, 4.0, 0.0, 0.0}); + auto exp = NDArrayFactory::create('c', {2, 3, 3}, { + -0.19595918, 0., 0., + 0.2612789, 0., 0., + -0.19595918, 0., 0., + 
0.2612789, 0., 0., + -0.19595918, 0., 0., + 0.2612789, 0., 0.} + ); + + sd::ops::clip_by_global_norm op; + auto result = op.evaluate({&x, &a}, {0.8}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + auto y = result.at(1); + //z->printIndexedBuffer("Output 1"); + //y->printIndexedBuffer("Output 2"); + //result.at(2)->printIndexedBuffer("Global norm is"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(result.at(2)->isScalar()); + ASSERT_TRUE(exp.equalsTo(z)); + ASSERT_TRUE(exp.equalsTo(y)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0}); + auto exp = NDArrayFactory::create({36.0, -48.0}); + + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + //z->printIndexedBuffer("Output "); + //exp.printIndexedBuffer("Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}, {1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0}); + auto exp = NDArrayFactory::create({-2.0, -2.0}); + + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + //z->printIndexedBuffer("Output "); + //exp.printIndexedBuffer("Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) { + + auto x = NDArrayFactory::create('c', {1, 3, 3}, 
{3.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 3.0}); + NDArray exp('c', {1}, std::vector{-54.0}); + + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + //z->printIndexedBuffer("Output "); + //exp.printIndexedBuffer("Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) { + + auto x = NDArrayFactory::create('c', {1, 3, 3}, {12.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 13.0}); + auto exp = NDArrayFactory::create('c', {1}, {189.0}); + + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + // z->printShapeInfo("Output shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) { + + auto x = NDArrayFactory::create('c', {1, 4, 4}); + NDArray exp('c', {1}, std::vector{-16.0}); + x.linspace(1); + x.p(5, 4.0); + x.p(12, 12.0); + + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + //z->printIndexedBuffer("Output "); + //exp.printIndexedBuffer("Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) { + + auto x = NDArrayFactory::create('c', {4, 4}); + auto exp = NDArrayFactory::create(-16.0); + x.linspace(1); + x.p(5, 4.0); + x.p(12, 12.0); + + sd::ops::matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); + + 
ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + //z->printIndexedBuffer("Output "); + //z->printShapeInfo("Shape"); + //exp.printIndexedBuffer("Expected "); + ASSERT_TRUE(z->isScalar()); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 3}, {-3.0, 0.0, 0.0, 0.0, 4.0, 0.0, 0.0, 0.0, -3.0, 4.0, 0.0, 0.0, 0.0, -3.0, 0.0, 0.0, 0.0, 4.0}); + auto exp = NDArrayFactory::create({3.58351893845611, 3.871201010907891}); + + sd::ops::log_matrix_determinant op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, LogDet_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98, 4,1.2,-1.6,1.2,3.7,-4.3,-1.6,-4.3,9.8}); + auto exp = NDArrayFactory::create({ 3.5835189, 4.159008}); + + sd::ops::logdet op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, LogDet_2) { + + auto x = NDArrayFactory::create('c', {1, 3, 3}, {4,12,-16,12,37,-43,-16,-43,98}); + auto exp = NDArrayFactory::create('c', {1}, { 3.5835189}); + + sd::ops::logdet op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests6, LogDet_3) { + + auto x = NDArrayFactory::create('c', {3, 3}, {4,12,-16,12,37,-43,-16,-43,98}); + auto exp = NDArrayFactory::create( 3.5835189); + + sd::ops::logdet op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixInverse_1) { + + auto x = NDArrayFactory::create('c', {2, 5, 5}, { + 2.f, 4.f, 60.f, 8.f, 10.f, + 0.f, 1.f, 2.f, 3.f, 4.f, + 0.f, 0.f, 2.f, 4.f, 6.f, + 0.f, 0.f, 0.f, 1.f, 2.f, + 0.f, 0.f, 0.f, 0.f, 4.f, + + 1.f, 0.f, 0.f, 0.f, 0.f, + 2.f, 1.f, 0.f, 0.f, 0.f, + 30.f, 2.f, 1.f, 0.f, 0.f, + 4.f, 3.f, 2.f, 1.f, 0.f, + 5.f, 4.f, 3.f, 2.f, 1.f + }); + + auto exp = NDArrayFactory::create('c', {2, 5, 5}, { + 0.5f, -2.0f, -13.0f, 54.0f, -6.75f, + 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, + 0.f, 0.f, 0.5f, -2.0f, 0.25f, + 0.f, 0.f, 0.f, 1.0f, -0.5f, + 0.f, 0.f, 0.f, 0.f, 0.25f, + + 1.0f, 0.0f, 0.0f, 0.0f, 0.f, + -2.0f, 1.0f, 0.f, 0.f, 0.f, + -26.0f, -2.0f, 1.f, 0.f, 0.f, + 54.0f, 1.0f, -2.0f, 1.f, 0.f, + -27.0f, 0.0f, 1.0f, -2.0f, 1.f, + }); + + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixInverse_010) { + + auto x = NDArrayFactory::create('c', {1, 5, 5}, {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f, }); + auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f}); + + 
sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixInverse_01) { + + auto x = NDArrayFactory::create('c', {1, 5, 5}, {2.f, 4.f, 60.f, 8.f, 10.f, 0.f, 1.f, 2.f, 3.f, 4.f, 0.f, 0.f, 2.f, 4.f, 6.f, 0.f, 0.f, 0.f, 1.f, 2.f, 0.f, 0.f, 0.f, 0.f, 4.f }); + + auto exp = NDArrayFactory::create('c', {1, 5, 5}, {0.5f, -2.0f, -13.0f, 54.0f, -6.75f, 0.0f, 1.0f, -1.0f, 1.0f, 0.0f, 0.f, 0.f, 0.5f, -2.0f, 0.25f, 0.f, 0.f, 0.f, 1.0f, -0.5f, 0.f, 0.f, 0.f, 0.f, 0.25f }); + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixInverse_02) { + + auto x = NDArrayFactory::create('c', {1, 5, 5}, {1.f, 0.f, 0.f, 0.f, 0.f, 2.f, 1.f, 0.f, 0.f, 0.f, 30.f, 2.f, 1.f, 0.f, 0.f, 4.f, 3.f, 2.f, 1.f, 0.f, 5.f, 4.f, 3.f, 2.f, 1.f }); + auto exp = NDArrayFactory::create('c', {1, 5, 5}, {1.0f, 0.0f, 0.0f, 0.0f, 0.f, -2.0f, 1.0f, 0.f, 0.f, 0.f, -26.0f, -2.0f, 1.f, 0.f, 0.f, 54.0f, 1.0f, -2.0f, 1.f, 0.f, -27.0f, 0.0f, 1.0f, -2.0f, 1.f }); + + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +/* +TEST_F(DeclarableOpsTests6, MatrixInverse_2) { + + auto x = NDArrayFactory::create('c', {2, 5, 5}, { + 1., 2., 30., 4., 5., + 0., 1., 2., 3., 4., + 0., 0., 1., 2., 3., + 0., 0., 0., 1., 2., + 0., 0., 0., 0., 1., + + 
4., 0., 0., 0., 0., + 4., 2., 0., 0., 0., + 30., 2., 1., 0., 0., + 8., 6., 4., 2., 0., + 15., 12., 9., 6., 3., + }); + + auto exp = NDArrayFactory::create('c', {2, 5, 5}, { + 1.0, -2.0, -26.0, 54.0, -27.0, + 0.0, 1.0, -2.0, 1.0, 0.0, + 0.0, 0.0, 1.0, -2.0, 1.0, + 0.0, 0.0, 0.0, 1.0, -2.0, + 0.0, 0.0, 0.0, 0.0, 1.0, + + 0.25, 0.0, 0.0, 0.0, 0.0, + -0.50, 0.5, 0.0, 0.0, 0.0, + -6.50, -1.0, 1.0, 0.0, 0.0, + 13.50, 0.5, -2.0, 0.5, 0.0, + -6.75, 0.0, 1.0, -1.0, 0.33333333 + }); + + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + z->printIndexedBuffer("Output "); + exp.printIndexedBuffer("Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +*/ +TEST_F(DeclarableOpsTests6, MatrixInverse_03) { + + auto x = NDArrayFactory::create('c', {5, 5}, { + 4.f, 0.f, 0.f, 0.f, 0.f, + 4.f, 2.f, 0.f, 0.f, 0.f, + 30.f, 2.f, 1.f, 0.f, 0.f, + 8.f, 6.f, 4.f, 2.f, 0.f, + 15.f, 12.f, 9.f, 6.f, 3.f, + }); + + auto exp = NDArrayFactory::create('c', {5, 5}, { + 0.25f, 0.0f, 0.0f, 0.0f, 0.0f, + -0.50f, 0.5f, 0.0f, 0.0f, 0.0f, + -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, + 13.50f, 0.5f, -2.0f, 0.5f, 0.0f, + -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f + }); + + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Output "); +// exp.printIndexedBuffer("Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixInverse_3) { + + auto x = NDArrayFactory::create('c', {5, 5}, { + 4.f, 0.f, 0.f, 0.f, 0.f, + 4.f, 2.f, 0.f, 0.f, 0.f, + 30.f, 2.f, 1.f, 0.f, 0.f, + 8.f, 6.f, 4.f, 2.f, 0.f, + 15.f, 12.f, 9.f, 6.f, 3.f, + }); + + auto exp = NDArrayFactory::create('c', {5, 5}, { + 0.25f, 0.0f, 0.0f, 0.0f, 0.0f, + -0.50f, 0.5f, 0.0f, 
0.0f, 0.0f, + -6.50f, -1.0f, 1.0f, 0.0f, 0.0f, + 13.50f, 0.5f, -2.0f, 0.5f, 0.0f, + -6.75f, 0.0f, 1.0f, -1.0f, 0.33333333f + }); + + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// exp.printIndexedBuffer("Expected "); +// z->printIndexedBuffer("Output "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixInverse_4) { + + auto x = NDArrayFactory::create('c', {5, 5}, { + 1.f, 2.f, 30.f, 4.f, 5.f, + 0.f, 1.f, 2.f, 3.f, 4.f, + 0.f, 0.f, 1.f, 2.f, 3.f, + 0.f, 0.f, 0.f, 1.f, 2.f, + 0.f, 0.f, 0.f, 0.f, 1.f + }); + + auto exp = NDArrayFactory::create('c', {5, 5}, { + 1.0f, -2.0f, -26.0f, 54.0f, -27.0f, + 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f + }); + + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Output "); +// exp.printIndexedBuffer("Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, MatrixInverse_04) { + + auto x = NDArrayFactory::create('c', {5, 5}, { + 1.f, 2.f, 30.f, 4.f, 5.f, + 0.f, 1.f, 2.f, 3.f, 4.f, + 0.f, 0.f, 1.f, 2.f, 3.f, + 0.f, 0.f, 0.f, 1.f, 2.f, + 0.f, 0.f, 0.f, 0.f, 1.f + }); + + auto exp = NDArrayFactory::create('c', {5, 5}, { + 1.0f, -2.0f, -26.0f, 54.0f, -27.0f, + 0.0f, 1.0f, -2.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, -2.0f, 1.0f, + 0.0f, 0.0f, 0.0f, 1.0f, -2.0f, + 0.0f, 0.0f, 0.0f, 0.0f, 1.0f + }); + + sd::ops::matrix_inverse op; + auto result = op.evaluate({&x}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); +// 
z->printIndexedBuffer("Output "); +// exp.printIndexedBuffer("Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, ReluLayer_1) { + auto x = NDArrayFactory::create('c', {3, 4}, {1.0, -2.0, 3.0, 4.0, 5.0, -6.0, 7.0, 8.0, 9.0, -10.0, 11.0, 12}); + auto w = NDArrayFactory::create('c', {4, 3}, {0.5, 0.1, 0.8, 0.5, 0.2, 0.5, 0.5, 0.25, 0.5, 0.1, 0.0, 0.25}); + auto b = NDArrayFactory::create({20.0, 30.0, 50.0}); + + + + auto exp = NDArrayFactory::create('c', {3, 3}, { + 21.4, 30.45, 52.3, + 23.8, 31.05, 56.5, + 26.2, 31.65, 60.7}); + + sd::ops::relu_layer op; + auto result = op.evaluate({&x, &w, &b}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printShapeInfo("Output shape"); + // z->printIndexedBuffer("Output "); + // exp.printIndexedBuffer("Expected "); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) { + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto y = NDArrayFactory::create('c', {3, 4, 5}); + + + std::vector dims = {0, 1}; + auto z = x.applyReduce3(reduce3::CosineSimilarity, y, dims); + ASSERT_TRUE(&z != nullptr); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, static_rnn_test1) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = 
NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484, 0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 , + 0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. , + 0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0. ,0., 0., 0., 0.,0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, static_rnn_test2) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333, + 0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0.97338548, 0.97338548, 0.97338548, 0.97338548, + 0.97732812, 0.97732812, 0.97732812, 0.97732812,0.97864398, 0.97864398, 0.97864398, 0.97864398,0.98000654, 
0.98000654, 0.98000654, 0.98000654, + 0.98112648, 0.98112648, 0.98112648, 0.98112648}); + + auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.98000654, 0.98000654, 0.98000654, 0.98000654,0.98112648, 0.98112648, 0.98112648, 0.98112648}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, static_rnn_test3) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, 0}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0., 0., 0., 0., 0.9312333, 0.9312333, 0.9312333, 0.9312333, + 0., 0., 0., 0. , 0.97136768, 0.97136768, 0.97136768, 0.97136768,0., 0., 0., 0. 
, + 0.97732812, 0.97732812, 0.97732812, 0.97732812,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.2 , 0.2 , 0.2 , 0.2}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, static_rnn_test4) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664, + 0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0., 0., 0., 0. 
, + 0.97688859, 0.97688859, 0.97688859, 0.97688859,0., 0., 0., 0., 0., 0., 0., 0.,0., 0., 0., 0.}); + + auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97688859, 0.97688859, 0.97688859, 0.97688859, 0.88400882, 0.88400882, 0.88400882, 0.88400882}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, static_rnn_test5) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.49676344, 0.49676344, 0.49676344, 0.49676344, 0.87018664, 0.87018664, 0.87018664, 0.87018664, + 0.88400882, 0.88400882, 0.88400882, 0.88400882, 0.96529784, 0.96529784, 0.96529784, 0.96529784,0.96849345, 0.96849345, 0.96849345, 0.96849345, + 0.97688859, 0.97688859, 0.97688859, 0.97688859,0.97831069, 0.97831069, 0.97831069, 0.97831069, 0.97997868, 0.97997868, 0.97997868, 0.97997868, + 0.98110653, 0.98110653, 0.98110653, 0.98110653}); + + auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97997868, 0.97997868, 0.97997868, 0.97997868, 0.98110653, 0.98110653, 0.98110653, 0.98110653}); + + sd::ops::static_rnn op; + auto results = op.evaluate({&x, &Wx, 
&Wh, &b}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) { + + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expH = NDArrayFactory::create('c', {time, bS, numUnitsFW+numUnitsBW}, {0.43819931, 0.43819931, 0.43819931, 0.86708881, 0.86708881,0.86708881,0.47615493, 0.47615493, 0.47615493, 0.78347842, 0.78347842,0.78347842, + 0.51241561, 0.51241561, 0.51241561, 0.55529176, 0.55529176,0.55529176,0., 0., 0., 0., 0.,0.,0.73880324, 0.73880324, 0.73880324, 0.90935605, 0.90935605, + 0.90935605, 0.77843476, 0.77843476, 0.77843476, 0.64692945, 0.64692945,0.64692945,0., 0., 0., 0., 0.,0.,0., 0., 0., 0., 0.,0., + 0.9052501, 0.9052501, 0.9052501, 0.9181592, 0.9181592, 0.9181592,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., + 0.9555734, 0.9555734, 0.9555734, 0.8026439, 0.8026439, 0.8026439,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}); + + auto 
expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2, 0.2, 0.2}); + auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25, 0.25, 0.25}); + + sd::ops::static_bidirectional_rnn op; + auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFWfinal = results.at(1); + auto hBWfinal = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) { + + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); + + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expH = NDArrayFactory::create('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86518273, 0.86518273,0.86518273,0.27105303, 0.27105303, 0.27105303, 0.66617761, 0.66617761,0.66617761, + 0.31492203, 0.31492203, 0.31492203, 0.31492203, 0.31492203,0.31492203,0. , 0. , 0. , 0. , 0. ,0. 
, + 0.60005558, 0.60005558, 0.60005558, 0.9029975 , 0.9029975 ,0.9029975 ,0.66138054, 0.66138054, 0.66138054, 0.43819931, 0.43819931,0.43819931, + 0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. , + 0.87023975, 0.87023975, 0.87023975, 0.88852032, 0.88852032,0.88852032,0. , 0. , 0. , 0. , 0. ,0. , + 0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. , + 0.95177305, 0.95177305, 0.95177305, 0.66737775, 0.66737775,0.66737775,0. , 0. , 0. , 0. , 0. ,0. , + 0. , 0. , 0. , 0. , 0. ,0. ,0. , 0. , 0. , 0. , 0. ,0. , + 0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., 0., 0., 0.}); + + auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.95177305, 0.95177305, 0.95177305, 0.66138054, 0.66138054, 0.66138054, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.}); + auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86518273, 0.86518273, 0.86518273, 0.66617761, 0.66617761, 0.66617761, 0.31492203, 0.31492203, 0.31492203, 0. , 0. , 0.}); + + sd::ops::static_bidirectional_rnn op; + auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFWfinal = results.at(1); + auto hBWfinal = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); + + +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) { + + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, 
numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expH = NDArrayFactory::create('c', {time, bS, numUnitsFW+numUnitsBW}, {0.22602835, 0.22602835, 0.22602835, 0.86841012, 0.86841012,0.86841012,0.27105303, 0.27105303, 0.27105303, 0.88207531, 0.88207531,0.88207531, + 0.31492203, 0.31492203, 0.31492203, 0.8941667 , 0.8941667 ,0.8941667 ,0.35748551, 0.35748551, 0.35748551, 0.90489713, 0.90489713, + 0.90489713, 0.60005558, 0.60005558, 0.60005558, 0.91381375, 0.91381375,0.91381375,0.66138054, 0.66138054, 0.66138054, 0.92253504, 0.92253504, + 0.92253504,0.71429879, 0.71429879, 0.71429879, 0.93027876, 0.93027876,0.93027876,0.75947891, 0.75947891, 0.75947891, 0.9371767 , 0.9371767 , + 0.9371767 , 0.87023975, 0.87023975, 0.87023975, 0.94014274, 0.94014274,0.94014274,0.89680574, 0.89680574, 0.89680574, 0.94648926, 0.94648926, + 0.94648926,0.91657261, 0.91657261, 0.91657261, 0.95204779, 0.95204779,0.95204779,0.93146896, 0.93146896, 0.93146896, 0.95694206, 0.95694206, + 0.95694206, 0.95177305, 0.95177305, 0.95177305, 0.93773086, 0.93773086,0.93773086,0.95874689, 0.95874689, 0.95874689, 0.94579176, 0.94579176, + 0.94579176,0.96416067, 0.96416067, 0.96416067, 0.95267886, 0.95267886,0.95267886,0.96851506, 0.96851506, 0.96851506, 0.95857985, 0.95857985, + 0.95857985, 0.97269956, 0.97269956, 0.97269956, 0.76075293, 0.76075293,0.76075293,0.97557464, 0.97557464, 0.97557464, 0.78024637, 0.78024637, + 0.78024637,0.97806922, 0.97806922, 0.97806922, 0.79833344, 0.79833344,0.79833344,0.98026195, 0.98026195, 0.98026195, 0.81508646, 0.81508646,0.81508646}); + + auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.97269956, 0.97269956, 0.97269956, 0.97557464, 0.97557464, 0.97557464, 0.97806922, 0.97806922, 0.97806922, 0.98026195, 0.98026195, 0.98026195}); + auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86841012, 0.86841012, 0.86841012, 0.88207531, 0.88207531, 
0.88207531, 0.8941667 , 0.8941667 , 0.8941667 , 0.90489713, 0.90489713, 0.90489713}); + + sd::ops::static_bidirectional_rnn op; + auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFWfinal = results.at(1); + auto hBWfinal = results.at(2); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create('c', {time, bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484,0.9312333 , 0.9312333 , 0.9312333 , 0.9312333 , + 0.93751527, 0.93751527, 0.93751527, 0.93751527,0.97136768, 0.97136768, 0.97136768, 0.97136768,0. , 0. , 0. , 0. , + 0.97732812, 0.97732812, 0.97732812, 0.97732812,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. 
}); + + auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97732812, 0.97732812, 0.97732812, 0.97732812, 0.93751527, 0.93751527, 0.93751527, 0.93751527}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); + + +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334, + 0.96778334,0.97309129, 0.97309129, 0.97309129, 0.97309129,0. , 0. , 0. , 0. 
, + 0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491, 0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828, + 0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782}); + + auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97309129, 0.97309129, 0.97309129, 0.97309129, 0.98120782, 0.98120782, 0.98120782, 0.98120782}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + auto h0 = NDArrayFactory::create('c', {bS, numUnits}); + + x.linspace(0.01, 0.01); + h0 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.92755601, 0.92755601, 0.92755601, 0.92755601,0.96778334, 0.96778334, 0.96778334, 0.96778334,0.97309129, + 0.97309129, 0.97309129, 0.97309129,0.97491207, 0.97491207, 0.97491207, 0.97491207,0.75001965, 0.75001965, 0.75001965, 0.75001965,0.95449491, 0.95449491, + 0.95449491, 0.95449491,0.97732828, 0.97732828, 0.97732828, 0.97732828,0.98000655, 0.98000655, 0.98000655, 0.98000655,0.98120782, 0.98120782, 0.98120782, 0.98120782}); + + auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97491207, 
0.97491207, 0.97491207, 0.97491207, 0.98120782, 0.98120782, 0.98120782, 0.98120782}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &h0}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-4}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.86347567, 0.86347567, 0.86347567, 0.86347567,0.96059545, 0.96059545, + 0.96059545, 0.96059545,0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0. , 0. , 0. , 0. , + 0.57368608, 0.57368608, 0.57368608, 0.57368608,0. , 0. , 0 , 0. ,0., 0. , 0, 0.,0., 0., 0. , 0. ,0. , 0. , 0., 0. 
}); + + auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.57368608, 0.57368608, 0.57368608, 0.57368608}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + + x.linspace(0.01, 0.01); + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expH = NDArrayFactory::create('c', {bS, time, numUnits}, {0.47615493, 0.47615493, 0.47615493, 0.47615493,0.86347567, 0.86347567, 0.86347567, 0.86347567,0.96059545, 0.96059545, 0.96059545, 0.96059545, + 0.9724738 , 0.9724738 , 0.9724738 , 0.9724738 ,0.97486307, 0.97486307, 0.97486307, 0.97486307,0.57368608, 0.57368608, 0.57368608, 0.57368608, + 0.92135149, 0.92135149, 0.92135149, 0.92135149,0.97482354, 0.97482354, 0.97482354, 0.97482354,0.97984727, 0.97984727, 0.97984727, 0.97984727, + 0.98119833, 0.98119833, 0.98119833, 0.98119833}); + + auto expHFinal = NDArrayFactory::create('c', {bS, numUnits}, {0.97486307, 0.97486307, 0.97486307, 0.97486307,0.98119833, 0.98119833, 0.98119833, 0.98119833}); + + sd::ops::dynamic_rnn op; + auto results = op.evaluate({&x, &Wx, &Wh, &b}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto h = results.at(0); + auto hFinal = results.at(1); + + 
ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expHFinal.isSameShape(hFinal)); + ASSERT_TRUE(expHFinal.equalsTo(hFinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) { + + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {time, bS, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create('c', {time, bS, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.47615493, 0.47615493, 0.47615493,0.51241561, 0.51241561, 0.51241561,0. , 0. , 0. , + 0.73880324, 0.73880324, 0.73880324,0.77843476, 0.77843476, 0.77843476,0. , 0. , 0. ,0. , 0. , 0. , + 0.9052501 , 0.9052501 , 0.9052501 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0.9555734 , 0.9555734 , 0.9555734 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); + + auto expHBW = NDArrayFactory::create('c', {time, bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881,0.78347842, 0.78347842, 0.78347842,0.55529176, 0.55529176, 0.55529176,0. , 0. , 0. , + 0.90935605, 0.90935605, 0.90935605,0.64692945, 0.64692945, 0.64692945,0. , 0. , 0. ,0. , 0. , 0. , + 0.9181592 , 0.9181592 , 0.9181592 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0.8026439 , 0.8026439 , 0.8026439 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. 
}); + + auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.9555734 , 0.9555734 , 0.9555734 , 0.77843476, 0.77843476, 0.77843476, 0.51241561, 0.51241561, 0.51241561, 0.2 , 0.2 , 0.2}); + auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.86708881, 0.86708881, 0.86708881, 0.78347842, 0.78347842, 0.78347842, 0.55529176, 0.55529176, 0.55529176, 0.25 , 0.25 , 0.25}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) { + + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 
0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0. , 0. , 0. , + 0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0.73978305, 0.73978305, 0.73978305,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); + + auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207,0.83584708, 0.83584708, 0.83584708,0.77435951, 0.77435951, 0.77435951,0.58760492, 0.58760492, 0.58760492,0. , 0. , 0. , + 0.85615841, 0.85615841, 0.85615841,0.67397984, 0.67397984, 0.67397984,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0.76576202, 0.76576202, 0.76576202,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); + + auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.87294706, 0.87294706, 0.87294706,0.84851124, 0.84851124, 0.84851124,0.73978305, 0.73978305, 0.73978305,0.2 , 0.2 , 0.2}); + auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.84345207, 0.84345207, 0.84345207, 0.85615841, 0.85615841, 0.85615841, 0.76576202, 0.76576202, 0.76576202, 0.25 , 0.25 , 0.25}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); + + +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) { + + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); + + auto maxTimeStep = NDArrayFactory::create('c', {bS}, {time-1, time-3, time-4, 0}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0. , 0. , 0. , + 0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0.61067683, 0.61067683, 0.61067683,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); + + auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707,0.77935851, 0.77935851, 0.77935851,0.6381121 , 0.6381121 , 0.6381121 ,0.35748551, 0.35748551, 0.35748551,0. , 0. , 0. , + 0.77843476, 0.77843476, 0.77843476,0.47615493, 0.47615493, 0.47615493,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0.61067683, 0.61067683, 0.61067683,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. , + 0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. ,0. , 0. , 0. }); + + auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.84784327, 0.84784327, 0.84784327, 0.7793996 , 0.7793996 , 0.7793996 , 0.61067683, 0.61067683, 0.61067683, 0. , 0. 
, 0.}); + auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.82273707, 0.82273707, 0.82273707, 0.77843476, 0.77843476, 0.77843476, 0.61067683, 0.61067683, 0.61067683, 0. , 0. , 0.}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &maxTimeStep}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) { + + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); + + auto h0FW = NDArrayFactory::create('c', {bS, numUnitsFW}); + auto h0BW = NDArrayFactory::create('c', {bS, numUnitsBW}); + + x.linspace(0.01, 0.01); + h0FW = 0.2; + h0BW = 0.25; + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.43819931, 0.43819931, 0.43819931,0.66617761, 0.66617761, 0.66617761,0.80944357, 0.80944357, 0.80944357,0.87294706, 0.87294706, 0.87294706,0.89948899, 0.89948899, 0.89948899, + 0.61067683, 0.61067683, 0.61067683,0.84851124, 0.84851124, 0.84851124,0.91925737, 0.91925737, 0.91925737,0.93751395, 0.93751395, 0.93751395,0.94544483, 
0.94544483, 0.94544483, + 0.73978305, 0.73978305, 0.73978305,0.92827068, 0.92827068, 0.92827068,0.95791111, 0.95791111, 0.95791111,0.96427356, 0.96427356, 0.96427356,0.96797541, 0.96797541, 0.96797541, + 0.83057887, 0.83057887, 0.83057887,0.96365083, 0.96365083, 0.96365083,0.97585698, 0.97585698, 0.97585698,0.97866981, 0.97866981, 0.97866981,0.9807326 , 0.9807326 , 0.9807326 }); + + auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722,0.86427295, 0.86427295, 0.86427295,0.8599919 , 0.8599919 , 0.8599919 ,0.80609463, 0.80609463, 0.80609463,0.61814662, 0.61814662, 0.61814662, + 0.91888753, 0.91888753, 0.91888753,0.92652672, 0.92652672, 0.92652672,0.92939674, 0.92939674, 0.92939674,0.90661931, 0.90661931, 0.90661931,0.74516764, 0.74516764, 0.74516764, + 0.95254269, 0.95254269, 0.95254269,0.95710717, 0.95710717, 0.95710717,0.96021584, 0.96021584, 0.96021584,0.95222547, 0.95222547, 0.95222547,0.83426363, 0.83426363, 0.83426363, + 0.97154357, 0.97154357, 0.97154357,0.97424915, 0.97424915, 0.97424915,0.97644817, 0.97644817, 0.97644817,0.97410547, 0.97410547, 0.97410547,0.89409962, 0.89409962, 0.89409962}); + + auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.89948899, 0.89948899, 0.89948899, 0.94544483, 0.94544483, 0.94544483, 0.96797541, 0.96797541, 0.96797541, 0.9807326 , 0.9807326 , 0.9807326 }); + auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.85301722, 0.85301722, 0.85301722, 0.91888753, 0.91888753, 0.91888753, 0.95254269, 0.95254269, 0.95254269, 0.97154357, 0.97154357, 0.97154357}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW, &h0FW, &h0BW}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + 
ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); + + +} + +TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) { + + const int bS = 4; + const int inSize = 4; + const int numUnitsFW = 3; + const int numUnitsBW = 3; + const int time = 5; + + auto x = NDArrayFactory::create('c', {bS, time, inSize}); + auto WxFW = NDArrayFactory::create('c', {inSize, numUnitsFW}); + auto WhFW = NDArrayFactory::create('c', {numUnitsFW, numUnitsFW}); + auto bFW = NDArrayFactory::create('c', {2*numUnitsFW}); + + x.linspace(0.01, 0.01); + WxFW = 0.3; + WhFW = 0.4; + bFW = 0.1; + + auto expHFW = NDArrayFactory::create('c', {bS, time, numUnitsFW}, {0.22602835, 0.22602835, 0.22602835,0.49994591, 0.49994591, 0.49994591,0.72869307, 0.72869307, 0.72869307,0.84784327, 0.84784327, 0.84784327,0.89357928, 0.89357928, 0.89357928, + 0.43819931, 0.43819931, 0.43819931,0.7793996 , 0.7793996 , 0.7793996 ,0.9053792 , 0.9053792 , 0.9053792 ,0.93546593, 0.93546593, 0.93546593,0.94518339, 0.94518339, 0.94518339, + 0.61067683, 0.61067683, 0.61067683,0.90347408, 0.90347408, 0.90347408,0.95538786, 0.95538786, 0.95538786,0.96406045, 0.96406045, 0.96406045,0.96795929, 0.96795929, 0.96795929, + 0.73978305, 0.73978305, 0.73978305,0.95499984, 0.95499984, 0.95499984,0.97535671, 0.97535671, 0.97535671,0.97864446, 0.97864446, 0.97864446,0.98073144, 0.98073144, 0.98073144}); + + auto expHBW = NDArrayFactory::create('c', {bS, time, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345,0.85160683, 0.85160683, 0.85160683,0.81997657, 0.81997657, 0.81997657,0.69228829, 0.69228829, 0.69228829,0.39861399, 0.39861399, 0.39861399, + 0.91865453, 0.91865453, 0.91865453,0.92528094, 0.92528094, 0.92528094,0.92212167, 0.92212167, 0.92212167,0.86418213, 0.86418213, 0.86418213,0.57969286, 0.57969286, 0.57969286, + 
0.95252666, 0.95252666, 0.95252666,0.95696305, 0.95696305, 0.95696305,0.95878749, 0.95878749, 0.95878749,0.93722463, 0.93722463, 0.93722463,0.71727031, 0.71727031, 0.71727031, + 0.97154234, 0.97154234, 0.97154234,0.97423089, 0.97423089, 0.97423089,0.976149 , 0.976149 , 0.976149 ,0.96878298, 0.96878298, 0.96878298,0.81508646, 0.81508646, 0.81508646}); + + auto expHFWfinal = NDArrayFactory::create('c', {bS, numUnitsFW}, {0.89357928, 0.89357928, 0.89357928, 0.94518339, 0.94518339, 0.94518339, 0.96795929, 0.96795929, 0.96795929, 0.98073144, 0.98073144, 0.98073144}); + auto expHBWfinal = NDArrayFactory::create('c', {bS, numUnitsBW}, {0.84882345, 0.84882345, 0.84882345, 0.91865453, 0.91865453, 0.91865453, 0.95252666, 0.95252666, 0.95252666, 0.97154234, 0.97154234, 0.97154234}); + + sd::ops::dynamic_bidirectional_rnn op; + auto results = op.evaluate({&x, &WxFW,&WhFW,&bFW, &WxFW,&WhFW,&bFW}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto hFW = results.at(0); + auto hBW = results.at(1); + auto hFWfinal = results.at(2); + auto hBWfinal = results.at(3); + + ASSERT_TRUE(expHFW.isSameShape(hFW)); + ASSERT_TRUE(expHFW.equalsTo(hFW)); + ASSERT_TRUE(expHBW.isSameShape(hBW)); + ASSERT_TRUE(expHBW.equalsTo(hBW)); + ASSERT_TRUE(expHFWfinal.isSameShape(hFWfinal)); + ASSERT_TRUE(expHFWfinal.equalsTo(hFWfinal)); + ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal)); + ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal)); + + +} + + +TEST_F(DeclarableOpsTests6, Test_Diag_119_1) { + auto x = NDArrayFactory::create('c', {3}, {0.15f, 0.25f, 0.35f}); + auto e = NDArrayFactory::create('c', {3, 3}, {0.15f, 0.0f, 0.0f, 0.0f, 0.25f, 0.0f, 0.0f, 0.0f, 0.35f}); + + sd::ops::diag op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(e, *result.at(0)); + + +} + +TEST_F(DeclarableOpsTests6, Test_Diag_119_2) { + auto x = NDArrayFactory::create('c', {1}, {0.15f}); + auto e = NDArrayFactory::create('c', {1, 1}, {0.15f}); + + sd::ops::diag op; + 
auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(e, *result.at(0)); + + +} + +TEST_F(DeclarableOpsTests6, Test_Diag_119_3) { + auto x = NDArrayFactory::create(0.15f); + auto e = NDArrayFactory::create('c', {1, 1}, {0.15f}); + + sd::ops::diag op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(e, *result.at(0)); + + +} + + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests7.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests7.cpp new file mode 100644 index 000000000..a78165686 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests7.cpp @@ -0,0 +1,7000 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 09.02.18. 
+// + + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace sd; +using namespace sd::graph; + + +class DeclarableOpsTests7 : public testing::Test { +public: + + DeclarableOpsTests7() { + printf("\n"); + fflush(stdout); + } +}; + +template +class TypedDeclarableOpsTests7 : public testing::Test { +public: + + TypedDeclarableOpsTests7() { + printf("\n"); + fflush(stdout); + } +}; + +typedef ::testing::Types TestingTypes; +TYPED_TEST_CASE(TypedDeclarableOpsTests7, TestingTypes); + +TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LARGE) { + double inputData[150] = { + 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, 2.16, 2.16, 2.16, 2.16, 2.16, 2.17 + }; + + auto x = NDArrayFactory::create(inputData,'c',{1,149}); + sd::ops::choose op; + //greater than test + auto result = op.evaluate({&x}, {0.0},{3}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(1); + + ASSERT_EQ(148,z->e(0)); + //ASSERT_TRUE(exp.isSameShape(z)); + + + +} + +TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_ZERO) { + std::vector data; + for(Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } + + + + auto x = NDArrayFactory::create('c',{1,4},data); 
+ sd::ops::choose op; + //greater than test + auto result = op.evaluate({&x}, {0.0},{3}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(1); + auto array = *z; + ASSERT_EQ(3,array.e(0)); + //ASSERT_TRUE(exp.isSameShape(z)); + + + +} + + +TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR) { + std::vector data; + for(Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } + + + + auto x = NDArrayFactory::create('c',{1,4},data); + auto scalar = NDArrayFactory::create('c',{1,1},{0.0}); + sd::ops::choose op; + //greater than test + auto result = op.evaluate({&x,&scalar}, {1.0},{3}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(3, z->lengthOf()); + //ASSERT_TRUE(exp.isSameShape(z)); + + + +} + + +TEST_F(DeclarableOpsTests7, Test_CHOOSE_SCALAR_LEFT) { + std::vector data; + for(Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } + + + + auto x = NDArrayFactory::create('c',{1,4},data); + auto scalar = NDArrayFactory::create('c',{1,1},{0.0}); + sd::ops::choose op; + //greater than test + auto result = op.evaluate({&scalar,&x}, {1.0},{3}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(3,z->lengthOf()); + //ASSERT_TRUE(exp.isSameShape(z)); + + + +} + + +TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR) { + std::vector data; + for(Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } + + + + auto x = NDArrayFactory::create('c',{1,4},data); + sd::ops::choose op; + //greater than test + auto result = op.evaluate({&x}, {1.0},{3}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(2,z->lengthOf()); + //ASSERT_TRUE(exp.isSameShape(z)); + + + +} + + +TEST_F(DeclarableOpsTests7, Test_CHOOSE_ONLY_SCALAR_GTE) { + std::vector data; + for(Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + } + + + + auto x = NDArrayFactory::create('c',{1,4},data); + sd::ops::choose op; + //greater than test + auto result = op.evaluate({&x}, {1.0},{5}); + 
ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(3,z->lengthOf()); + //ASSERT_TRUE(exp.isSameShape(z)); + + + +} + + +TEST_F(DeclarableOpsTests7, TEST_WHERE) { + std::vector data; + std::vector mask; + std::vector put; + std::vector resultData; + std::vector assertion; + for(Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + if(i > 1) { + assertion.push_back(5.0); + mask.push_back(true); + } + else { + assertion.push_back(i); + mask.push_back(false); + } + + put.push_back(5.0); + resultData.push_back(0.0); + } + + + + + auto x = NDArrayFactory::create('c',{1,4},data); + auto maskArr = NDArrayFactory::create('c',{1,4},mask); + auto putArr = NDArrayFactory::create('c',{1,4},put); + auto resultArr = NDArrayFactory::create('c',{1,4},resultData); + sd::ops::where_np op; + //greater than test + // Nd4jStatus execute(std::initializer_list*> inputs, std::initializer_list*> outputs , std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false); + + auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, {}, false); + ASSERT_EQ(Status::OK(), result); + for(int i = 0; i < 4; i++) + ASSERT_EQ(assertion[i],resultArr.e(i)); + // auto z = result.at(0); + //ASSERT_EQ(4,z->lengthOf()); + //ASSERT_TRUE(exp.isSameShape(z)); + + +} + +TEST_F(DeclarableOpsTests7, TEST_WHERE_MASK) { + double x[300] = 
{1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0}; + double z[300] = 
{1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0}; + bool mask[300] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + double put[200] = {0.99666107,0.9867112,0.97686064,0.9671082,0.95745337,0.9478948,0.9384318,0.92906314,0.9197881,0.91060543,0.9015147,0.8925147,0.8836044,0.8747831,0.86605,0.85740393,0.8488442,0.84037,0.83198035,0.8236745,0.8154515,0.8073106,0.79925096,0.79127187,0.7833724,0.77555174,0.76780915,0.7601439,0.75255525,0.7450422,0.7376043,0.73024046,0.72295034,0.715733,0.7085876,0.7015135,0.69451016,0.68757665,0.6807124,0.6739167,0.66718876,0.66052806,0.6539338,0.6474054,0.6409421,0.6345435,0.6282087,0.6219371,0.6157281,0.60958105,0.6034956,0.59747064,0.5915059,0.5856007,0.57975453,0.5739667,0.5682366,0.5625637,0.5569475,0.5513874,0.54588276,0.540433,0.53503764,0.5296962,0.52440816,0.51917285,0.5139898,0.5088585,0.50377846,0.4987491,0.4937699,0.48884052,0.48396033,0.47912875,0.47434545,0.4696099,0.46492168,0.46028027,0.45568514,0.4511359,0.44663212,0.4421733,0.43775895,0.43338865,0.42906195,0.42477852,0.4205379,0.41633952,0.41218308,0.40806815,0.40399432,0.3999611,0.3959682,0.39201516,0.38810158,0.384227,0.38039115,0.37659356,0.37283397,0.3691119,0.36542687,0.36177874,0.35816705,0.3545914,0.35105142,0.34754673,0.34407702,0.34064204,0.33724132,0.3338745,0.33054137,0.3272415,0.32397458,0.32074028,0.3175382,0.31436813,0.31122974,0.3081226,0.30504647,0.30200112,0.2989862,0.29600134,0.29304633,0.2901207,0.28722438,0.28435695,0.2815181,0.27870762,0.27592525,0.27317056,0.27044344,0.26774356,0.26507056,0.2624243,0.25980446,0.25721073,0.25464293,0.25210077,0.249584,0.24709237,0.24462552,0.24218333,0.23976555,0.23737194,0.23500215,0.23265606,0.23033342,0.22803394,0.22575743,0.2235036,0.22127232,0.21906327,0.21687631,0.21471114,0.21256764,0.21044552,0.20834461,0.20626466,0.20420544,0.20216681,0.20014854,0.19815037,0.19617215,0.19421372,0.19227484,0.19035533,0.18845497,0.18657354,0.18471093,0.18286693,0.18104129,0.17923392,0.17744459,0.17567308,0.1739193,0.1
7218304,0.17046405,0.16876228,0.16707748,0.16540948,0.16375816,0.16212334,0.16050482,0.15890247,0.15731607,0.15574552,0.15419069,0.15265137,0.15112738,0.14961864,0.14812498,0.14664622,0.1451822,0.14373279,0.14229788,0.14087726,0.13947085,0.13807845,0.13669999,0.13533528}; + double assertion[300] = {1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,
1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,1.000000000000000000e+00,9.966611049434810354e-01,9.867111603284486332e-01,9.768605487739230320e-01,9.671082786103732953e-01,9.574533680683808834e-01,9.478948451798039354e-01,9.384317476799283186e-01,9.290631229105962285e-01,9.197880277243004610e-01,9.106055283892373620e-01,9.015147004953073528e-01,8.925146288610534828e-01,8.836044074415293492e-01,8.747831392370875037e-01,8.660499362030764647e-01,8.574039191604412302e-01,8.488442177072155204e-01,8.403699701308978698e-01,8.319803233217017979e-01,8.236744326866727306e-01,8.154514620646623468e-01,8.073105836421510251e-01,7.992509778699116163e-01,7.912718333805045523e-01,7.833723469065965173e-01,7.755517232000953554e-01,7.678091749520912224e-01,7.601439227135980969e-01,7.525551948170853267e-01,7.450422272987937689e-01,7.376042638218265335e-01,7.302405556000080011e-01,7.229503613225031211e-01,7.157329470791886639e-01,7.085875862867698771e-01,7.015135596156351072e-01,6.945101549174396149e-01,6.875766671534137009e-01,6.807123983233853703e-01,6.739166573955123196e-01,6.671887602367149173e-01,6.605280295438040739e-01,6.539337947752965619e-01,6.474053920839111242e-01,6.409421642497381555e-01,6.345434606140767375e-01,6.282086370139332576e-01,6.219370557171712832e-01,
6.157280853583116942e-01,6.095811008749726367e-01,6.034954834449430816e-01,5.974706204238864338e-01,5.915059052836644238e-01,5.856007375512777280e-01,5.797545227484157682e-01,5.739666723316099173e-01,5.682366036329845604e-01,5.625637398015992385e-01,5.569475097453767676e-01,5.513873480736106725e-01,5.458826950400470501e-01,5.404329964865340896e-01,5.350377037872348085e-01,5.296962737933965659e-01,5.244081687786711354e-01,5.191728563849821176e-01,5.139898095689314772e-01,5.088585065487419845e-01,5.037784307517284565e-01,4.987490707622945774e-01,4.937699202704479151e-01,4.888404780208293054e-01,4.839602477622509946e-01,4.791287381977387683e-01,4.743454629350723484e-01,4.696099404378203390e-01,4.649216939768630041e-01,4.602802515824001017e-01,4.556851459964368911e-01,4.511359146257447605e-01,4.466320994952920342e-01,4.421732472021388527e-01,4.377589088697927955e-01,4.333886401030203062e-01,4.290620009431086457e-01,4.247785558235752101e-01,4.205378735263185508e-01,4.163395271382073215e-01,4.121830940081024908e-01,4.080681557043087104e-01,4.039942979724505667e-01,3.999611106937689398e-01,3.959681878438343627e-01,3.920151274516718853e-01,3.881015315592946102e-01,3.842270061816405180e-01,3.803911612669100828e-01,3.765936106572991271e-01,3.728339720501240850e-01,3.691118669593352886e-01,3.654269206774144463e-01,3.617787622376523182e-01,3.581670243768036999e-01,3.545913434981138868e-01,3.510513596347161203e-01,3.475467164133922426e-01,3.440770610186974499e-01,3.406420441574410929e-01,3.372413200235238606e-01,3.338745462631242389e-01,3.305413839402346898e-01,3.272414975025391692e-01,3.239745547476344245e-01,3.207402267895853032e-01,3.175381880258169032e-01,3.143681161043347383e-01,3.112296918912743071e-01,3.081225994387726264e-01,3.050465259531625062e-01,3.020011617634821843e-01,2.989862002903017069e-01,2.960013380148582840e-01,2.930462744485015647e-01,2.901207121024425017e-01,2.872243564578055852e-01,2.843569159359789489e-01,2.815181018692606840e-01,2.787076284717992514e-01,
2.759252128108221624e-01,2.731705747781537075e-01,2.704434370620155681e-01,2.677435251191103149e-01,2.650705671469821278e-01,2.624242940566549609e-01,2.598044394455423789e-01,2.572107395706292876e-01,2.546429333219200064e-01,2.521007621961529055e-01,2.495839702707757235e-01,2.470923041781825646e-01,2.446255130802063582e-01,2.421833486428674187e-01,2.397655650113727777e-01,2.373719187853666479e-01,2.350021689944260528e-01,2.326560770738031469e-01,2.303334068404078172e-01,2.280339244690317291e-01,2.257573984688081292e-01,2.235035996599082919e-01,2.212723011504689752e-01,2.190632783137518302e-01,2.168763087655291855e-01,2.147111723416972873e-01,2.125676510761114746e-01,2.104455291786438698e-01,2.083445930134591173e-01,2.062646310775079761e-01,2.042054339792348794e-01,2.021667944174980747e-01,2.001485071607009836e-01,1.981503690261307848e-01,1.961721788595043592e-01,1.942137375147174327e-01,1.922748478337968081e-01,1.903553146270518526e-01,1.884549446534251604e-01,1.865735466010380594e-01,1.847109310679319050e-01,1.828669105430000552e-01,1.810412993871116094e-01,1.792339138144224131e-01,1.774445718738737465e-01,1.756730934308744496e-01,1.739193001491673995e-01,1.721830154728755669e-01,1.704640646087285105e-01,1.687622745084652875e-01,1.670774738514141378e-01,1.654094930272448083e-01,1.637581641188943782e-01,1.621233208856623365e-01,1.605047987464754966e-01,1.589024347633189727e-01,1.573160676248336609e-01,1.557455376300762306e-01,1.541906866724424563e-01,1.526513582237501165e-01,1.511273973184814046e-01,1.496186505381822129e-01,1.481249659960175158e-01,1.466461933214808777e-01,1.451821836452561187e-01,1.437327895842310799e-01,1.422978652266598532e-01,1.408772661174743090e-01,1.394708492437411185e-01,1.380784730202649913e-01,1.366999972753347725e-01,1.353352832366127023e-01}; + Nd4jLong threeHundredShapePointer[8] = {2,1,300,1,1,0,1,99}; + Nd4jLong twoHundredShapePointer[8] = {2,1,200,1,1,0,1,99}; + sd::ops::where_np op; + 
ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::DOUBLE); + ArrayOptions::setDataType(twoHundredShapePointer, sd::DataType::DOUBLE); + + NDArray xArr(x,threeHundredShapePointer); + NDArray putArr(put,twoHundredShapePointer); + NDArray resultArr(z,threeHundredShapePointer); + + resultArr.assign(0.0); + ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::BOOL); + NDArray maskArr(mask,threeHundredShapePointer); + + ArrayOptions::setDataType(threeHundredShapePointer, sd::DataType::DOUBLE); + NDArray assertArr(assertion, threeHundredShapePointer); + Nd4jStatus result = op.execute({&maskArr, &xArr, &putArr},{&resultArr},{},{},{}); + ASSERT_EQ(Status::OK(),result); + ASSERT_TRUE(assertArr.isSameShape(resultArr)); + ASSERT_TRUE (assertArr.equalsTo(resultArr)); +} + +TEST_F(DeclarableOpsTests7, TEST_WHERE_SCALAR) { + std::vector data; + std::vector mask; + std::vector put; + std::vector resultData; + std::vector assertion; + for(Nd4jLong i = 0; i < 4; i++) { + data.push_back(i); + if(i > 1) { + assertion.push_back(5.0); + mask.push_back(true); + } + else { + assertion.push_back(i); + mask.push_back(false); + } + + resultData.push_back(0.0); + } + + + put.push_back(5.0); + + + auto x = NDArrayFactory::create('c',{1,4},data); + auto maskArr = NDArrayFactory::create('c',{1,4},mask); + auto putArr = NDArrayFactory::create('c',{1,1},put); + auto resultArr = NDArrayFactory::create('c',{1,4},resultData); + sd::ops::where_np op; + //greater than test + // Nd4jStatus execute(std::initializer_list*> inputs, std::initializer_list*> outputs , std::initializer_list tArgs, std::initializer_list iArgs, bool isInplace = false); + + auto result = op.execute({&maskArr,&x,&putArr},{&resultArr}, {},{3}, {}, {}, false); + // ASSERT_EQ(Status::OK(), result.status()); + for(int i = 0; i < 4; i++) + ASSERT_EQ(assertion[i],resultArr.e(i)); + // auto z = result.at(0); + //ASSERT_EQ(4,z->lengthOf()); + //ASSERT_TRUE(exp.isSameShape(z)); + + +} + + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestMatrixDiagPart_1) { + auto x = NDArrayFactory::create('c', {2, 4, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); + + auto z = NDArrayFactory::create('c', {2, 4}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0}); + + sd::ops::matrix_diag_part op; + + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestMatrixDiagPart_2) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0.}); + + auto z = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); + + sd::ops::matrix_diag_part op; + + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestMatrixDiag_1) { + auto z = NDArrayFactory::create('c', {2, 4, 4}, {1., 0., 0., 0., 0., 2., 0., 0., 0., 0., 3., 0., 0., 0., 0., 4.,5., 0., 0., 0., 0., 6., 0., 0., 0., 0., 7., 0., 0., 0., 0., 8.}); + + auto x = NDArrayFactory::create('c', {2, 4}, {1, 2, 3, 4, 5, 6, 7, 8}); + + sd::ops::matrix_diag op; + + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestMatrixDiag_2) { + auto z = NDArrayFactory::create('c', {2, 3, 3}, {1., 0., 0., 0., 2., 0., 0., 0., 3.,5., 0., 0., 0., 6., 0.,0., 0., 7.}); + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 5, 6, 7}); + + 
sd::ops::matrix_diag op; + + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(z.equalsTo(result.at(0))); + + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRandomCrop_1) { + auto x = NDArrayFactory::create('c', {2, 2, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); + auto shape = NDArrayFactory::create({1, 2, 3}); + sd::ops::random_crop op; + + auto result = op.evaluate({&x, &shape}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); +// ASSERT_TRUE(z.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRandomCrop_2) { + auto x = NDArrayFactory::create('c', {2, 2, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); + auto shape = NDArrayFactory::create({2, 2, 2}); + sd::ops::random_crop op; + + auto result = op.evaluate({&x, &shape}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); +// ASSERT_TRUE(z.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119) { + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); + auto data0 = NDArrayFactory::create('c', {2,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); + + auto data1 = NDArrayFactory::create('c', {2,3,5,4},{1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 
19.f, 20.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f, + 29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f, + 57.f, 58.f, 59.f, 60.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f,81.f, 82.f, 83.f, 84.f, + 85.f, 86.f, 87.f, 88.f,89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, + 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f}); + + auto data2 = NDArrayFactory::create('c', {3,1,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f}); + + auto exp = NDArrayFactory::create('c', {11, 5, 4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, + 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f, + 1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 
39.f, 40.f,81.f, 82.f, 83.f, 84.f,85.f, 86.f, 87.f, 88.f, + 89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f, + 53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); + + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); +// result.at(0)->printIndexedBuffer("Output"); +// exp.printIndexedBuffer("Expect"); +// result.at(0)->printShapeInfo("Output shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_Prof_1) { + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); + auto data0 = NDArrayFactory::create('c', {2,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); + + auto data1 = NDArrayFactory::create('c', {2,3,5,4},{1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f, + 29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f, + 57.f, 58.f, 59.f, 60.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f,81.f, 82.f, 83.f, 84.f, + 85.f, 86.f, 87.f, 88.f,89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 
96.f,97.f, 98.f, 99.f, 100.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, + 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f}); + + auto data2 = NDArrayFactory::create('c', {3,1,5,4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f}); + + auto exp = NDArrayFactory::create('c', {11, 5, 4}, {1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f, + 49.f, 50.f, 51.f, 52.f,53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,101.f, 102.f, 103.f, 104.f,105.f, 106.f, 107.f, 108.f,109.f, 110.f, 111.f, 112.f, + 113.f, 114.f, 115.f, 116.f,117.f, 118.f, 119.f, 120.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f,61.f, 62.f, 63.f, 64.f,65.f, 66.f, 67.f, 68.f,69.f, 70.f, 71.f, 72.f,73.f, 74.f, 75.f, 76.f,77.f, 78.f, 79.f, 80.f, + 1.f, 2.f, 3.f, 4.f,5.f, 6.f, 7.f, 8.f,9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, + 25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f,81.f, 82.f, 83.f, 84.f,85.f, 86.f, 87.f, 88.f, + 89.f, 90.f, 91.f, 92.f,93.f, 94.f, 95.f, 96.f,97.f, 98.f, 99.f, 100.f,41.f, 42.f, 43.f, 44.f,45.f, 46.f, 47.f, 48.f,49.f, 50.f, 51.f, 52.f, + 53.f, 54.f, 55.f, 56.f,57.f, 58.f, 59.f, 60.f,21.f, 22.f, 23.f, 24.f,25.f, 26.f, 27.f, 28.f,29.f, 30.f, 31.f, 32.f,33.f, 34.f, 35.f, 36.f,37.f, 38.f, 39.f, 40.f}); + + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&indices0, &indices1, 
&indices2, &data0, &data1, &data2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); +// result.at(0)->printIndexedBuffer("Output"); +// exp.printIndexedBuffer("Expect"); +// result.at(0)->printShapeInfo("Output shape"); + auto res = result.at(0); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + int numOfCases = 100; + auto timeStart = std::chrono::system_clock::now(); + + for (int i = 0; i < numOfCases; i++) { + op.execute({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {res}, {}, {}, {}); + } + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + //nd4j_printf("dynamic_stitch: Process with %i iterations was load: %lld us.\n", numOfCases, outerTime / numOfCases); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_1) { + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); + + auto data0 = NDArrayFactory::create('c', {2,5,4}); + auto data1 = NDArrayFactory::create('c', {2,3,5,4}); + auto data2 = NDArrayFactory::create('c', {3,1,5,4}); + + auto exp = NDArrayFactory::create('c', {11, 5, 4}, { + 21, 22, 23, 24, + 25, 26, 27, 28, + 29, 30, 31, 32, + 33, 34, 35, 36, + 37, 38, 39, 40, + + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + + 181, 182, 183, 184, + 185, 186, 187, 188, + 189, 190, 191, 192, + 193, 194, 195, 196, + 197, 198, 199, 200, + + 121, 122, 123, 124, + 125, 126, 127, 128, + 129, 130, 131, 132, + 133, 134, 135, 136, + 137, 138, 139, 140, + + 161, 162, 163, 164, + 165, 166, 167, 168, + 169, 170, 171, 172, + 173, 174, 175, 176, + 177, 178, 179, 180, + + 81, 82, 83, 84, + 85, 86, 87, 88, + 89, 90, 91, 92, + 93, 94, 95, 96, + 97, 98, 99, 100, + + 141, 142, 
143, 144, + 145, 146, 147, 148, + 149, 150, 151, 152, + 153, 154, 155, 156, + 157, 158, 159, 160, + + 41, 42, 43, 44, + 45, 46, 47, 48, + 49, 50, 51, 52, + 53, 54, 55, 56, + 57, 58, 59, 60, + + 101, 102, 103, 104, + 105, 106, 107, 108, + 109, 110, 111, 112, + 113, 114, 115, 116, + 117, 118, 119, 120, + + 61, 62, 63, 64, + 65, 66, 67, 68, + 69, 70, 71, 72, + 73, 74, 75, 76, + 77, 78, 79, 80, + + 21, 22, 23, 24, + 25, 26, 27, 28, + 29, 30, 31, 32, + 33, 34, 35, 36, + 37, 38, 39, 40, + }); + data0.linspace(1); + data1.linspace(21); + data2.linspace(141); + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(exp)); + ASSERT_TRUE(z->equalsTo(exp)); + + +} + +TEST_F(DeclarableOpsTests7, Test_Dynamic_Stitch_119_2) { + auto indices0 = NDArrayFactory::create('c', {2}, {1, 10}); + auto indices1 = NDArrayFactory::create('c', {2, 3}, {0, 7, 9, 5, 8, 3}); + auto indices2 = NDArrayFactory::create('c', {3, 1}, {6, 4, 2}); + + auto data0 = NDArrayFactory::create('c', {2,5,4}); + auto data1 = NDArrayFactory::create('c', {2,3,5,4}); + auto data2 = NDArrayFactory::create('c', {3,1,5,4}); + + auto exp = NDArrayFactory::create('c', {11, 5, 4}, { + 41, 42, 43, 44, + 45, 46, 47, 48, + 49, 50, 51, 52, + 53, 54, 55, 56, + 57, 58, 59, 60, + + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + + 201, 202, 203, 204, + 205, 206, 207, 208, + 209, 210, 211, 212, + 213, 214, 215, 216, + 217, 218, 219, 220, + + 141, 142, 143, 144, + 145, 146, 147, 148, + 149, 150, 151, 152, + 153, 154, 155, 156, + 157, 158, 159, 160, + + 181, 182, 183, 184, + 185, 186, 187, 188, + 189, 190, 191, 192, + 193, 194, 195, 196, + 197, 198, 199, 200, + + 101, 102, 103, 104, + 105, 106, 107, 108, + 109, 110, 111, 112, + 113, 114, 115, 116, + 117, 118, 119, 120, + + 161, 162, 163, 164, + 165, 166, 167, 168, + 169, 170, 
171, 172, + 173, 174, 175, 176, + 177, 178, 179, 180, + + 61, 62, 63, 64, + 65, 66, 67, 68, + 69, 70, 71, 72, + 73, 74, 75, 76, + 77, 78, 79, 80, + + 121, 122, 123, 124, + 125, 126, 127, 128, + 129, 130, 131, 132, + 133, 134, 135, 136, + 137, 138, 139, 140, + + 81, 82, 83, 84, + 85, 86, 87, 88, + 89, 90, 91, 92, + 93, 94, 95, 96, + 97, 98, 99, 100, + + 21, 22, 23, 24, + 25, 26, 27, 28, + 29, 30, 31, 32, + 33, 34, 35, 36, + 37, 38, 39, 40, + }); + data0.linspace(1); + data1.linspace(41); + data2.linspace(161); + sd::ops::dynamic_stitch op; + auto result = op.evaluate({&indices0, &indices1, &indices2, &data0, &data1, &data2}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(exp)); + ASSERT_TRUE(z->equalsTo(exp)); + + +} + +TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119) { + auto x = NDArrayFactory::create('c', {5, 4, 11}); + auto y = NDArrayFactory::create('c', {5, 4}, {0,1,2,3, 1,0,2,3, 2,3,1,0, 2,1,0,3, 0,1,2,3}); + auto e = NDArrayFactory::create('c', {5, 11}); + x.assign(1.f); + e.assign(1.f); + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {4}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(4, result.size()); + auto z = result.at(0); + + ASSERT_TRUE(e.isSameShape(z)); + + +} + +TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_1) { + auto x = NDArrayFactory::create('c', {3, 4, 2}, {10, 20,11, 21,12, 22,13, 23,14, 24,15, 25,16, 26,17, 27,18, 28,19, 29,20, 30,21, 31}); + + auto y = NDArrayFactory::create('c', {3, 4}, {0,0,0,0, 2,2,2,2, 2,1,1,1}); + auto e = NDArrayFactory::create('c', {4, 2}, {10, 20, 11, 21, 12, 22, 13, 23}); + +// x.assign(1.f); +// e.assign(1.f); + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {3}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(3, result.size()); + auto z = result.at(0); +// z->printShapeInfo("Output shape info"); +// result.at(1)->printShapeInfo("Shape2"); +// 
result.at(2)->printShapeInfo("Shape3"); +// result.at(3)->printShapeInfo("Shape4"); +// z->printIndexedBuffer("Output1"); +// result.at(1)->printIndexedBuffer("Output2"); +// result.at(2)->printIndexedBuffer("Output3"); +// result.at(3)->printIndexedBuffer("Output4"); + ASSERT_TRUE(e.isSameShape(z)); + + +} +TEST_F(DeclarableOpsTests7, Test_Dynamic_Partition_119_2) { + auto x = NDArrayFactory::create('c', {5, 4, 11}); + auto y = NDArrayFactory::create('c', {5, 4}, {0,1,2,3, 1,0,2,3, 2,3,1,0, 2,1,0,3, 0,1,2,3}); + auto e1 = NDArrayFactory::create('c', {5, 11}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, + 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, + 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, + 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187}); + auto e2 = NDArrayFactory::create('c', {5, 11}, { 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, + 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, + 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, + 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198}); + auto e3 = NDArrayFactory::create('c', {5, 11}, {23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, + 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, + 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, + 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, + 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209}); + auto e4 = NDArrayFactory::create('c', {5, 11}, { 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, + 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, + 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, + 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, + 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220}) ; + std::vector e({&e1, &e2, &e3, &e4}); + x.linspace(1.f); + //.assign(1.f); + sd::ops::dynamic_partition op; + auto result = op.evaluate({&x, &y}, {}, {4}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(4, result.size()); + for (size_t i 
= 0; i < result.size(); i++) { + auto z = result.at(i); +// z->printShapeInfo("Output shape info"); +// z->printIndexedBuffer("Output1"); +// result.at(1)->printIndexedBuffer("Output2"); +// result.at(2)->printIndexedBuffer("Output3"); +// result.at(3)->printIndexedBuffer("Output4"); + ASSERT_TRUE(e[i]->isSameShape(z)); + ASSERT_TRUE(e[i]->equalsTo(z)); + } + + +} + + +TEST_F(DeclarableOpsTests7, Test_SequenceMask_1) { + auto input = NDArrayFactory::create('c', {4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + auto exp = NDArrayFactory::create('c', {4, 4, 16}, {1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0,1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 }); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Output"); +// z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + + +} + +TEST_F(DeclarableOpsTests7, Test_SequenceMask_2) { + auto input = NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printBuffer("Output"); +// z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests7, Test_SequenceMask_3) { + auto input = NDArrayFactory::create('c', {2, 2, 2}, {10, 20, 30, 4, 0, 6, 7, 8}); + auto exp = NDArrayFactory::create('c', {2, 2, 2, 30}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {sd::DataType::INT32}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printBuffer("Output"); +// 
z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests7, Test_SequenceMask_4) { + auto input = NDArrayFactory::create({1, 3, 2}); + auto maxLen = NDArrayFactory::create(5); + auto exp = NDArrayFactory::create('c', {3,5}, { + 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f + }); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input, &maxLen}, {sd::DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printBuffer("Output"); +// z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(DeclarableOpsTests7, Test_SequenceMask_5) { + auto input = NDArrayFactory::create({1, 3, 2}); + auto exp = NDArrayFactory::create('c', {3,5}, { + 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 0.f, 0.f + }); + + sd::ops::sequence_mask op; + auto result = op.evaluate({&input}, {5, (int)sd::DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printBuffer("Output"); +// z->printShapeInfo("Shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMax_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2}); + + sd::ops::segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printBuffer("MaX1"); +// exp.printBuffer("ExP1"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMax_01) { + 
auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1., 10, 40, 30}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4, 5,5, 5}); + auto exp = NDArrayFactory::create({2.5, 9, 3, 9, 4.2, 40}); + + sd::ops::segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printBuffer("MaX01"); +// exp.printBuffer("ExP01"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMaxBP_1) { + auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({0., 1., 0., 2., 0., 0., 3., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); + auto eps = NDArrayFactory::create('c', {5}); + sd::ops::segment_max_bp op; + eps.linspace(1); + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("OutputMaxBP"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMax_2) { + auto x = NDArrayFactory::create('c', {5, 4}, { 0, 1.8, 2.5, 4., + 1, 9., 2.1, 2.4, + 0, 3., 9., 2.1, + 2, 1, 2.1, 0.7, + 3, 4.2, 2.2, 1. 
}); + auto idx = NDArrayFactory::create({0, 0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, {1, 9, 9, 4, + 2, 1, 2.1, 0.7, + 3, 4.2, 2.2, 1.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + sd::ops::segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + auto out = result.at(0); +// out->printIndexedBuffer("Output2Max"); +// exp.printIndexedBuffer("Expect2Max"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMaxBP_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); +// NDArray exp('c', {3, 4}, {2.1, 2.5, 4, 9,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); + auto exp = NDArrayFactory::create('c', {4, 4}, {0., 2., 3., 4., 1., 0., 0., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + + sd::ops::segment_max_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + //exp.printIndexedBuffer("BP Max Expect"); + //result.at(0)->printIndexedBuffer("BP Max Output"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMax_3) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. 
, 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64.,55.1, 46.4, 73. , 28.,119.1, 12.1,112.7, 13.1,14. ,114.2, 16.2,117.,51. , 42. , 87. , 44., + 55.1, 56.4, 93. , 28.,119.1, 82.1,112.7,113.1,114. ,114.2,116.2,117.,91. , 82. , 37. , 64.,55.1, 46.4, 73. , 28., 119.1, 12.1,112.7, 13.1,14. ,114.2, 16.2,117. }); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + + sd::ops::segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output3Max"); +// result.at(0)->printShapeInfo("Out Shape 3 Max"); +// exp.printIndexedBuffer("Expect3Max"); +// exp.printShapeInfo("Exp Shape 3 Max"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMax_4) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24., + 15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. , + 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create('c', {8, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. 
,0. , 0. , 0. , 0. , + 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , + 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + + sd::ops::segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); + //result.at(0)->printShapeInfo("Out Shape"); + //exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_1) { + auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({4, 4, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 0, 0}); + auto exp = NDArrayFactory::create({2.2, 9., 3., 9., 4.2}); + + sd::ops::unsorted_segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMaxBP_1) { + auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({0., 1., 0., 2., 0., 0., 3., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); + auto eps = NDArrayFactory::create('c', {5}); + sd::ops::segment_max_bp op; + eps.linspace(1); + auto result = op.evaluate({&x, 
&idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMaxBP_2) { + auto x = NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({3., 0., 1., 0., 2., 0., 0., 4., 0., 0.,0., 0., 0., 5., 0.,0.}); + auto eps = NDArrayFactory::create('c', {5}); + sd::ops::segment_max_bp op; + eps.linspace(1); + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_2) { + auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({4, 4, 1, 1, 1, 1, 3, 3, 3, 3, 4, 4, 4, 4, 0, 0}); + auto exp = NDArrayFactory::create({2.2, 9., -DataTypeUtils::max(), 9., 4.2}); + + sd::ops::unsorted_segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("OutputUnsortedMax"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_3) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. 
}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, {2.1, 2.5, 4, 9,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + + sd::ops::unsorted_segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + //exp.printIndexedBuffer("Expect"); + //result.at(0)->printIndexedBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMax_4) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., 2.1, 2.4, 3., 8., 2.1, 2.1, 11.7, 0.1, 3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 0, 2}); + double principalMax = DataTypeUtils::max(); + auto exp = NDArrayFactory::create('c', {3, 4}, {2.1, 2.5, 11.7, 9, + -principalMax, -principalMax, -principalMax, -principalMax, + 3., 4.2, 2.2, 1.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + + sd::ops::unsorted_segment_max op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + //exp.printIndexedBuffer("Expect"); + //result.at(0)->printIndexedBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMin_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4, 3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); + + sd::ops::segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMin_01) { + auto x = NDArrayFactory::create({1.8, -2.5,4., -9., 2.1, 2.4,-3.,-9., 2.1, 2.1,0.7, 0.1, 3., -4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({-2.5, -9, -3., -9, -4.2}); + + sd::ops::segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMin_02) { + auto x = NDArrayFactory::create({1.8f, -2.5f, 4.f, -9.f, 2.1f, 2.4f, -3.f, -9.f, 2.1f, 2.1f,0.7f, 0.1f, 3.f, -4.2f, 2.2f, 1.f}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({-2.5f, -9.f, -3.f, -9.f, -4.2f}); + + sd::ops::segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMinBP_1) { + auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({ 1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + eps.linspace(1); + sd::ops::segment_min_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMinBP_1) { + auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({ 1., 0., 0., 0., 2., 0., 3., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + eps.linspace(1); + sd::ops::unsorted_segment_min_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output1"); + //exp.printIndexedBuffer("Expecte"); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMinBP_2) { + auto x = NDArrayFactory::create({3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({3., 1., 0., 0., 0., 2., 0., 0., 4., 4., 0., 5., 0., 0., 0., 0.}); + auto eps = NDArrayFactory::create('c', {5}); + eps.linspace(1); + sd::ops::unsorted_segment_min_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output1"); + //exp.printIndexedBuffer("Expecte"); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMin_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, {1.8, 2.4, 3. , 9.,2.1, 2.1, 0.7, 0.1,3. 
, 4.2, 2.2, 1.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + + sd::ops::segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMinBP_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create('c', {3, 4}, {1., 2., 3. , 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1., 0., 0., 4., 0., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + + sd::ops::segment_min_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); +// exp.printIndexedBuffer("Expect"); +// result.at(0)->printIndexedBuffer("Output"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMin_3) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1, + 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. , + 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. 
}); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,31. , 22. , 67. , 24. , + 15.1, 46.4, 73. , 28. ,109.1, 12.1, 12.7, 13.1,14. , 14.2, 16.2, 11. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + + sd::ops::segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMin_4) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create('c', {8, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. 
,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + + sd::ops::segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); + //result.at(0)->printShapeInfo("Out Shape"); + //exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); + + sd::ops::unsorted_segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_01) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({1.8, 2.1, 3., 2.1, 0.1}); + + sd::ops::unsorted_segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_2) { + auto x = 
NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, {1.8, 2.4, 3. , 9.,2.1, 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + + sd::ops::unsorted_segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_3) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28.,109.1, 82.1, 12.7, 113.1, + 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. , + 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,31. , 22. , 67. , 24. , + 15.1, 46.4, 73. , 28. ,109.1, 12.1, 12.7, 13.1,14. , 14.2, 16.2, 11. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. 
}); + + sd::ops::unsorted_segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMin_4) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, {91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., + 51., 42., 67., 24., 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., + 31., 22., 87., 44., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., + 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + double principalMax = DataTypeUtils::max(); + + auto exp = NDArrayFactory::create('c', {8, 4, 4}, { + 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., 51., + 42., 67., 24., 15.1, 56.4, 93., 28., 109.1, 82.1, 12.7, 113.1, 114., 14.2, 116.2, 11., + principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, + 31., 22., 87., 44., 55.1, 46.4, 73., 28., 119.1, 12.1, 112.7, 13.1, 14., 114.2, 16.2, 117., + principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, 
principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, + principalMax, principalMax, principalMax, principalMax, principalMax, principalMax, + 91., 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + + sd::ops::unsorted_segment_min op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); + //result.at(0)->printShapeInfo("Out Shape"); + // exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMean_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); + + sd::ops::segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +TEST_F(DeclarableOpsTests7, TestSegmentMean_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, { 1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3. 
, 4.2, 2.2, 1.}); + + sd::ops::segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); +// exp.printIndexedBuffer("Expect"); +// result.at(0)->printIndexedBuffer("Output"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +TEST_F(DeclarableOpsTests7, TestSegmentMean_02) { + auto x = NDArrayFactory::create('c', {6, 3}, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); + auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + + sd::ops::segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +TEST_F(DeclarableOpsTests7, TestSegmentMean_021) { + auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); + auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); + + sd::ops::segment_mean op; + x.linspace(1.); + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +TEST_F(DeclarableOpsTests7, TestSegmentMean_022) { + auto x = NDArrayFactory::create('c', {6, 3});//, {1, 2, 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 2,2}); + auto z = NDArrayFactory::create('c', {3, 3}); //, { 2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5}); + auto exp = NDArrayFactory::create('c', {3, 3}, { 2.5f, 3.5f, 4.5f, 8.5f, 9.5f, 10.5f, 14.5f, 15.5f, 16.5f}); + + sd::ops::segment_mean op; + x.linspace(1.); + auto result = 
op.execute({&x, &idx}, {&z}); + ASSERT_EQ(result, Status::OK()); + + ASSERT_TRUE(exp.equalsTo(z)); + +// +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}, { 0.5, 1., 1.5, 2., 0.5, 1., 1.5, 2., 5., 6., 7., 8., 9., 10., 11., 12.}); + eps.linspace(1); + + sd::ops::segment_mean_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMean_3) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , + 41. , 32. , 77. , 34. ,35.1 , 51.4 , 83. , 28. ,114.1 , 47.1 , 62.7, 63.1,64. , 64.2 , 66.2 , 64. , + 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. 
}); + + sd::ops::segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMean_4) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create('c', {8, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , + 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. 
}); + + sd::ops::segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); + //result.at(0)->printShapeInfo("Out Shape"); + //exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({2.15, 4.375, 3., 4.4, 1.8666667}); + + sd::ops::unsorted_segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentMeanBP_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create({1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 3., 4./3., 4./3., 4./3., + 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); + sd::ops::segment_mean_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMeanBP_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 
0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create({1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 3., 4./3., 4./3., 4./3., + 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); + sd::ops::unsorted_segment_mean_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMeanBP_2) { + auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create({3., 1./2., 1./2., 2./4., 2./4., 2./4., 2./4, 4./3., 4./3., 4./3., + 5./6., 5./6., 5./6., 5./6., 5./6., 5./6.}); + sd::ops::unsorted_segment_mean_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, { 1.95, 2.45, 3.5, 9., 2.1, 2.1, 0.7, 0.1, 3. 
, 4.2, 2.2, 1.}); + + sd::ops::unsorted_segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); +// exp.printIndexedBuffer("Expect"); +// result.at(0)->printIndexedBuffer("Output"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_3) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , + 41. , 32. , 77. , 34. ,35.1 , 51.4 , 83. , 28. ,114.1 , 47.1 , 62.7, 63.1,64. , 64.2 , 66.2 , 64. , + 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. }); + + sd::ops::unsorted_segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentMean_4) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. 
, 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create('c', {8, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , + 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. 
}); + + sd::ops::unsorted_segment_mean op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); + //result.at(0)->printShapeInfo("Out Shape"); + //exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({3.0405593, 8.75, 3., 7.621024, 4.5723805}); + + sd::ops::unsorted_segment_sqrt_n op; + + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_BP_1) { + auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); +// NDArray exp({3.0405593, 8.75, 3., 7.621024, 4.5723805}); + auto exp = NDArrayFactory::create({3., 0.707107, 0.707107, 1., 1., 1., 1., 2.309401, 2.309401, 2.309401, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241, 2.041241}); + sd::ops::unsorted_segment_sqrt_n_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Hello Out:"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, 
TestUnsortedSegmentSqrtN_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, { 2.7577164, 3.4648232, 4.9497476, 12.727922, + 2.1, 2.1, 0.7, 0.1, + 3. , 4.2, 2.2, 1. + }); + + sd::ops::unsorted_segment_sqrt_n op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); +// exp.printIndexedBuffer("Expect"); +// result.at(0)->printIndexedBuffer("Output"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_3) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. , + 57.982758, 45.254833, 108.89445, 48.083263, 49.638893, 72.69058, 117.37973, 39.59798, 161.36177, 66.60946, 88.67119, 89.23688, 90.50967, 90.79251, 93.62093, 90.50967, + 91. , 82. , 37. , 64. ,55.1 , 46.4 , 73. , 28. ,119.1 , 12.1 , 112.7 , 13.1,14. , 114.2 , 16.2 , 117. 
}); + + sd::ops::unsorted_segment_sqrt_n op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_4) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create('c', {8, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , + 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. 
}); + + sd::ops::unsorted_segment_sqrt_n op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); + //result.at(0)->printShapeInfo("Out Shape"); + //exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_5) { + auto x = NDArrayFactory::create({1.,2.,5.,7.,3.,1.,3.,4.}); + auto idx = NDArrayFactory::create({3, 1, 0, 0, 2, 0, 3, 2}); + //NDArray exp({1.7320508075688772, 1., 1.4142135623730951, 1.4142135623730951}); + auto exp = NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); + sd::ops::unsorted_segment_sqrt_n op; + + auto result = op.evaluate({&x, &idx}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSqrtN_6) { + auto x = NDArrayFactory::create({5,1,7,2,3,4,1,3}); + auto idx = NDArrayFactory::create({0,0,0,1,2,2,3,3}); + //NDArray exp({1.7320508075688772, 1., 1.4142135623730951, 1.4142135623730951}); +// auto exp = NDArrayFactory::create({7.5055537, 2., 4.9497476, 2.828427}); + sd::ops::unsorted_segment_sqrt_n op; + +try { + auto result = op.evaluate({&x, &idx}, {}, {1}); + ASSERT_NE(result.status(), Status::OK()); +} +catch (std::exception& err) { + +} + // result.at(0)->printIndexedBuffer("Output"); + // exp.printIndexedBuffer("Expect"); + //ASSERT_TRUE(exp.equalsTo(result.at(0))); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentSum_1) { + auto x = 
NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({ 0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); + + sd::ops::segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentSumBP_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create({ 1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); + sd::ops::segment_sum_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSumBP_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create({ 1., 1., 2., 2., 2., 2., 3., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); + sd::ops::unsorted_segment_sum_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSumBP_2) { + auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. 
}); + auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create({ 3., 1., 1., 2., 2., 2., 2., 4., 4., 4., 5., 5., 5., 5., 5., 5.}); + sd::ops::unsorted_segment_sum_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentSum_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, {3.9 , 4.9, 7. , 18.,2.1 , 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); + + sd::ops::segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentSumBP_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1. , 2., 3., 4., 1. 
, 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {3, 4}); + eps.linspace(1); + + sd::ops::segment_sum_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentSum_3) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,82. , 64. , 154. , 68. , + 70.2, 102.8, 166. , 56. ,228.2, 94.2, 125.4, 126.2 ,128. , 128.4, 132.4, 128. ,91. , 82. , 37. , 64. , + 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. 
}); + + sd::ops::segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentSum_4) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create('c', {8, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , + 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. 
}); + + sd::ops::segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); + //result.at(0)->printShapeInfo("Out Shape"); + //exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.3, 17.5, 3., 13.2, 11.2}); + + sd::ops::unsorted_segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9.,2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, {3.9 , 4.9, 7. , 18.,2.1 , 2.1, 0.7, 0.1,3. , 4.2, 2.2, 1.}); + + sd::ops::unsorted_segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_3) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1, 114. 
, 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,82. , 64. , 154. , 68. , + 70.2, 102.8, 166. , 56. ,228.2, 94.2, 125.4, 126.2 ,128. , 128.4, 132.4, 128. ,91. , 82. , 37. , 64. , + 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + + sd::ops::unsorted_segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentSum_4) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, {91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24.,15.1, 56.4, 93. , 28., + 109.1, 82.1, 12.7, 113.1,114. , 14.2, 116.2, 11. ,31. , 22. , 87., 44. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. , + 91. , 82. , 37., 64. ,55.1, 46.4, 73., 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 3, 7}); + auto exp = NDArrayFactory::create('c', {8, 4, 4}, { + 91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,51. , 42. , 67. , 24. ,15.1, 56.4, 93. , 28. , + 109.1, 82.1, 12.7, 113.1,114. 
, 14.2, 116.2, 11. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 31. , 22. , 87. , 44. ,55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. , + 0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,0. , 0. , 0. , 0. ,91. , 82. , 37. , 64. ,55.1, 46.4, 73. , 28. , + 119.1, 12.1, 112.7, 13.1,14. , 114.2, 16.2, 117. }); + + sd::ops::unsorted_segment_sum op; + + auto result = op.evaluate({&x, &idx}, {}, {8}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProd_1) { + auto x = NDArrayFactory::create({1.8, 2.5, 4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); + + sd::ops::segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProdBP_1) { + auto x = NDArrayFactory::create({ 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create({2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 
13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); + sd::ops::segment_prod_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("ProdBP Output"); +// exp.printIndexedBuffer("ProdBP Expect"); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_1) { + auto x = NDArrayFactory::create({ 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create({2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 3., 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); + sd::ops::segment_prod_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("ProdBP Output"); + //exp.printIndexedBuffer("ProdBP Expect"); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_2) { + auto x = NDArrayFactory::create({ 3., 1.8, 2.5, 4., 9., 2.1, 2.4, 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto eps = NDArrayFactory::create({1., 2., 3., 4., 5.}); + auto exp = NDArrayFactory::create({3., 2.5, 1.8, 90.72, 40.32, 172.8, 151.2, 17.64, 75.6, 75.6, 13.86, 97.02, 3.234, 2.31, 4.41, 9.702}); + auto n = NDArrayFactory::create(5LL); + sd::ops::unsorted_segment_prod_bp op; + + auto result = op.evaluate({&x, &idx, &eps, &n}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Unsorted ProdBP Output"); + //exp.printIndexedBuffer("Unsorted ProdBP Expect"); + 
+ ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProd_2) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + + sd::ops::segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProdBP_2) { + auto x = NDArrayFactory::create('c', {4, 4}, {1.8, 2.5, 4., 9., + 2.1, 2.4, 3., 9., + 2.1, 2.1, 0.7, 0.1, + 3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}, {2.1, 4.8, 9., 36., 1.8, 5., 12., 36., 5., 6., 7., 8., 9., 10., 11., 12.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + eps.linspace(1); + sd::ops::segment_prod_bp op; + + auto result = op.evaluate({&x, &idx, &eps}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 2); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProd_3) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. 
, 24., + 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , + 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, { + 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , + 1581, 924, 5829, 1056,832.01001, 2616.9602, 6789, 784, 12993.810, 993.41003, 1431.2899, 1481.61, 1596, 1621.64, 1882.4401, 1287, + 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); + + sd::ops::segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProd_04) { + auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); + auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + + sd::ops::segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProd_05) { + auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); + +// ---------------------------------------------------------------- + + auto idx = 
NDArrayFactory::create({0,0,1,2,2,2,3,3}); + auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + + sd::ops::segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto res = result.at(0); +// res->printIndexedBuffer("Segment prod 05"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProd_05_1) { + auto x = NDArrayFactory::create({1,2,3,4,5,6,7,8 }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); + auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + + sd::ops::segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + auto res = result.at(0); +// res->printIndexedBuffer("Segment prod 05_1"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProd_06) { + auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8' }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); + auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + sd::ops::segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProd_07) { + auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8' }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0,0,1,2,2,2,3,3}); + auto exp = NDArrayFactory::create({ 2, 3, 120, 56}); + sd::ops::segment_prod op; + + auto 
result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestSegmentProd_08) { + auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8', '\x9', '\xA' }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0,0,2,2,2,2,3,3,3,3}); + auto exp = NDArrayFactory::create({ 2, 1,360, 5040}); + sd::ops::segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_1) { + auto x = NDArrayFactory::create({1.8, 2.5,4., 9., 2.1, 2.4,3.,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({0, 0, 1, 1, 1, 1, 2, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); + + sd::ops::unsorted_segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_11) { + auto x = NDArrayFactory::create({3.,1.8, 2.5,4., 9., 2.1, 2.4,9., 2.1, 2.1,0.7, 0.1, 3., 4.2, 2.2, 1.}); + auto idx = NDArrayFactory::create({2, 0, 0, 1, 1, 1, 1, 3, 3, 3, 4, 4, 4, 4, 4, 4}); + auto exp = NDArrayFactory::create({4.5, 181.44, 3., 39.69, 1.9404}); + + sd::ops::unsorted_segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {5}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_2) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1. }); + auto idx = NDArrayFactory::create({0, 0, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. , 4.2, 2.2 , 1.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + + sd::ops::unsorted_segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_12) { + auto x = NDArrayFactory::create('c', {4, 4}, { + 3., 4.2, 2.2, 1., + 1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9., 2.1, 2.1, 0.7, 0.1 }); + auto idx = NDArrayFactory::create({2, 0, 0, 1}); + auto exp = NDArrayFactory::create('c', {3, 4}, { 3.78, 6. , 12. , 81., 2.1 , 2.1, 0.7 , 0.1, 3. 
, 4.2, 2.2 , 1.}); + + //{ 2.1, 2.5, 4., 9., 2.1, 2.1, 0.7, 0.1, 3., 4.2, 2.2, 1.} + + sd::ops::unsorted_segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_EQ(result.size(), 1); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_08) { + auto x = NDArrayFactory::create({'\x1','\x2','\x3','\x4','\x5','\x6','\x7','\x8', '\x9', '\xA' }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0,0,2,2,2,2,3,3,3,3}); + auto exp = NDArrayFactory::create({ 2, 1,360, 5040}); + sd::ops::unsorted_segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_3) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., + 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , + 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, { + 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. 
, + 1581, 924, 5829, 1056,832.01001, 2616.9602, 6789, 784, 12993.810, 993.41003, 1431.2899, 1481.61, 1596.0000, 1621.6399, 1882.4401, 1287, + 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. ,119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. }); + + sd::ops::unsorted_segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_4) { + auto x = NDArrayFactory::create('c', {4, 4, 4}, { + 91. , 82. , 37. , 64. , 55.1, 46.4, 73. , 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 51. , 42. , 67. , 24., + 15.1, 56.4, 93. , 28., 109.1, 82.1, 12.7, 113.1, 114. , 14.2, 116.2, 11. , 31. , 22. , 87., 44. , 55.1, 46.4, 73., 28. , + 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. , 91. , 82. , 37., 64. , 55.1, 46.4, 73., 28. , 119.1, 12.1, 112.7, 13.1, 14. , 114.2, 16.2, 117. 
}); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({1, 1, 1, 2}); + auto exp = NDArrayFactory::create('c', {3, 4, 4}, { + 1., 1., 1., 1., 1., 1.,1.,1., 1.,1.,1.,1., 1.,1.,1.,1., + + 143871, 75768, 215673, 67584., 45843.75, 121426.96, 495597, 21952, + 1547562.8, 12020.262, 161306.38, 19409.092, 22344, 185191.27, 30495.531, 150579, + + 91., 82., 37., 64, 55.1, 46.400002, 73, 28, 119.1, 12.1, 112.7, 13.1, 14, 114.2, 16.2, 117}); + + sd::ops::unsorted_segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {3}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); + //exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProd_5) { + auto x = NDArrayFactory::create('c', {8, 15}); + +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({3, 1, 2, 1, 2, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {4, 15}, { + 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., + 1., 1., 1., 1., 1., + 78016., 85493., 93312., 101479., 110000., + 118881., 128128., 137747., 147744., 158125., + 168896., 180063., 191632., 203609., 216000., + 172081., 182528., 193347., 204544., 216125., + 228096., 240463., 253232., 266409., 280000., + 294011., 308448., 323317., 338624., 354375., + 76., 154., 234., 316., 400., + 486., 574., 664., 756., 850., + 946., 1044., 1144., 1246., 1350.}); + x.linspace(1.); + + sd::ops::unsorted_segment_prod op; + + auto result = op.evaluate({&x, &idx}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); +// result.at(0)->printShapeInfo("Out Shape"); + //exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + 
ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestUnsortedSegmentProdBP_4) { + auto x = NDArrayFactory::create('c', {8}, { + 5,1,7,2,3,4,1,3}); + auto gradO = NDArrayFactory::create('c', {4}, {1,2,3,4}); +// ---------------------------------------------------------------- + + auto idx = NDArrayFactory::create({0,0,0,1,2,2,3,3}); + auto exp = NDArrayFactory::create('c', {8}, { + 7.000000, 35.000000, 5.000000, 2.000000, 12.000000, 9.000000, 12.000000, 4.000000 + }); +// 1., 1., 1., 1., 1., 1.,1.,1., 1.,1.,1.,1., 1.,1.,1.,1., +// +// 143871, 75768, 215673, 67584., 45843.75, 121426.96, 495597, 21952, +// 1547562.8, 12020.262, 161306.38, 19409.092, 22344, 185191.27, 30495.531, 150579, +// +// 91., 82., 37., 64, 55.1, 46.400002, 73, 28, 119.1, 12.1, 112.7, 13.1, 14, 114.2, 16.2, 117}); + + sd::ops::unsorted_segment_prod_bp op; + + auto result = op.evaluate({&x, &idx, &gradO}, {}, {4}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output"); + //result.at(0)->printShapeInfo("Out Shape"); + //exp.printIndexedBuffer("Expect"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_1) { + auto x = NDArrayFactory::create('c', {2,4, 4, 4}, { + 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 14., 114., 16.2, 117., + 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 
87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}); + +// ---------------------------------------------------------------- + + + auto exp = NDArrayFactory::create('c', {2, 4, 4, 4}, { + 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 14., 114., 16.2, 117., + 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 91., 82., 37., 64., 55.1, 46.4, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}); + + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {1,1,1,1,1,1,0}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); + //result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_2) { + auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { + 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 
17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + +//Images shape is (3, 3, 4, 3) +//[1, 1, 1, 1] +//[1, 3, 2, 1] +auto exp = NDArrayFactory::create('c', {3, 1, 1, 12}, { + 11., 12., 13., 12., 13., 14., 1., 2., 3., 2., 3., 4., + 9., 8., 7., 6., 5., 4., 3., 2., 1., 0., 1., 2., + 211., 12., 13., 12., 213., 14., 21., 2., 3., 2., 3., 24. + }); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2,2, 3,3, 1,1,0}); + ASSERT_EQ(result.status(), Status::OK()); + + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_3) { + auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { + 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + +//Images shape is (3, 3, 4, 3) +//[1, 1, 1, 1] +//[1, 3, 2, 1] +auto exp = NDArrayFactory::create('c', {3, 1, 2, 6}, { + 11., 12., 13., 5., 6., 7., 15., 16., 17., 35., 36., 37., 9., 8., + 7., 15., 16., 17., 49., 48., 47., 135., 136., 137., 211., 12., 13., 25., + 6., 7., 15., 216., 17., 35., 36., 327. 
+ }); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2,1,3,2,2,2,0}); + ASSERT_EQ(result.status(), Status::OK()); + + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_4) { + auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { + 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + +//Images shape is (3, 3, 4, 3) +//[1, 1, 1, 1] +//[1, 3, 2, 1] +auto exp = NDArrayFactory::create('c', {3, 3, 4, 3}, { + 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., + 211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {1,1,1,1,1,1,0}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); + 
//result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_5) { + auto x = NDArrayFactory::create('c', {3, 3, 4, 3}, { + 11., 12., 13., 12., 13., 14., 15., 16., 17., 18., 19., 10., + 1., 2., 3., 2., 3., 4., 21., 22., 23., 22., 23., 24., + 5., 6., 7., 8., 9., 0., 35., 36., 37., 38., 39., 40., + 9., 8., 7., 6., 5., 4., 49., 48., 47., 46., 45., 44., + 3., 2., 1., 0., 1., 2., 53., 52., 51., 50., 51., 52., + 15., 16., 17., 18., 19., 10., 135., 136., 137., 138., 139., 140., +211., 12., 13., 12., 213., 14., 15., 216., 17., 128., 19., 10., + 21., 2., 3., 2., 3., 24., 21., 22., 223., 22., 223., 24., + 25., 6., 7., 8., 9., 20., 35., 36., 327., 38., 239., 40.}); + +//Images shape is (3, 3, 4, 3) +//[1, 1, 1, 1] +//[1, 3, 2, 1] +auto exp = NDArrayFactory::create('c', {3, 1, 1, 18}, { + 11., 12., 13., 15., 16., 17., 1., 2., 3., 21., 22., 23., 5., 6., 7., 35., 36., 37., + 9., 8., 7., 49., 48., 47., 3., 2., 1., 53., 52., 51., 15., 16., 17., 135., 136., 137., + 211., 12., 13., 15., 216., 17., 21., 2., 3., 21., 22., 223., 25., 6., 7., 35., 36., 327. 
+ +//Patch shape is (3, 1, 2, 18) + + }); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {3,2,3,2,1,2,0}); + ASSERT_EQ(result.status(), Status::OK()); +// result.at(0)->printIndexedBuffer("Output"); + //result.at(0)->printShapeInfo("Out Shape"); +// exp.printIndexedBuffer("Expect"); + //exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_6) { + auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 +}); + +//Images shape is (3, 3, 4, 3) +//[1, 1, 1, 1] +//[1, 3, 2, 1] +auto exp = NDArrayFactory::create('c', {2, 1, 4, 4}, { + 11.11, 11.12, 12.11, 12.12, 11.21, 11.22, 12.21, 12.22, 11.31, 11.32, 12.31, 12.32, 11.41, 11.42, 12.41, 12.42, + 21.11, 21.12, 22.11, 22.12, 21.21, 21.22, 22.21, 22.22, 21.31, 21.32, 22.31, 22.32, 21.41, 21.42, 22.41, 22.42 + }); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2,1, 1,1, 1,1,0}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_7) { + auto x = NDArrayFactory::create('c', {1, 3, 3, 1}); + x.linspace(1); + +//Images shape is (1, 3, 3, 4) +//[1, 1, 1, 1] +//[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 3, 3, 4}, { + 1., 2., 4., 5., 2., 3., 5., 6., 3., 0., 6., 0., + 
4., 5., 7., 8., 5., 6., 8., 9., 6., 0., 9., 0., 7., 8., 0., 0., 8., 9., 0., 0., 9., 0., 0., 0. }); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); +// output->printBuffer("Output"); +// exp.printBuffer("Expect"); +// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) +// if (exp.e(e) != output->e(e)) +// printf("%lld ", e); +// printf("\n"); + //result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_8) { + auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); + x.linspace(1); + +//Images shape is (1, 3, 3, 4) +//[1, 1, 1, 1] +//[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 3, 3, 8}, { + 1, 2, 3, 4, 7, 8, 9, 10, 3, 4, 5, 6, 9, 10, 11, 12, 5, 6, 0, 0, 11, 12, 0, 0, + 7, 8, 9, 10, 13, 14, 15, 16, 9, 10, 11, 12, 15, 16, 17, 18, 11, 12, 0, 0, 17, 18, 0, 0, + 13, 14, 15, 16, 0, 0, 0, 0, 15, 16, 17, 18, 0, 0, 0, 0, 17, 18, 0, 0, 0, 0, 0, 0 }); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); +// output->printBuffer("Output"); +// exp.printBuffer("Expect"); +// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) +// if (exp.e(e) != output->e(e)) +// printf("%lld ", e); +// printf("\n"); + //result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_9) { + auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); + x.linspace(1); + +//Images shape is (1, 3, 3, 4) +//[1, 1, 1, 1] +//[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 6, 6, 18}, { + 0., 0., 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 0., 0., 13., 14., 15., 16., + 0., 0., 0., 0., 0., 0., 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., + 0., 0., 0., 0., 0., 0., 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., + 0., 0., 0., 0., 0., 0., 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., + 0., 0., 0., 0., 0., 0., 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., + 0., 0., 0., 0., 0., 0., 9., 10., 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., + 0., 0., 1., 2., 3., 4., 0., 0., 13., 14., 15., 16., 0., 0., 25., 26., 27., 28., + 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., + 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., + 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., + 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., + 9., 10., 11., 12., 0., 0., 21., 22., 23., 24., 0., 0., 33., 34., 35., 36., 0., 0., + 0., 0., 13., 14., 15., 16., 0., 0., 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., + 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., + 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., + 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., + 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., + 21., 22., 23., 24., 0., 0., 33., 34., 35., 36., 0., 0., 45., 46., 47., 48., 0., 0., + 0., 0., 25., 26., 27., 28., 0., 0., 37., 38., 39., 40., 0., 0., 49., 50., 51., 52., + 25., 26., 27., 28., 29., 30., 37., 
38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., + 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., + 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., + 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., + 33., 34., 35., 36., 0., 0., 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., 0., 0., + 0., 0., 37., 38., 39., 40., 0., 0., 49., 50., 51., 52., 0., 0., 61., 62., 63., 64., + 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., + 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., + 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., + 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72., + 45., 46., 47., 48., 0., 0., 57., 58., 59., 60., 0., 0., 69., 70., 71., 72., 0., 0., + 0., 0., 49., 50., 51., 52., 0., 0., 61., 62., 63., 64., 0., 0., 0., 0., 0., 0., + 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., 0., 0., 0., 0., 0., 0., + 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., 0., 0., 0., 0., 0., 0., + 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., 0., 0., 0., 0., 0., 0., + 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., + 57., 58., 59., 60., 0., 0., 69., 70., 71., 72., 0., 0., 0., 0., 0., 0., 0., 0.}); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {3,3, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); +// output->printBuffer("OutputSame"); +// exp.printBuffer("ExpectSame"); +// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) +// if (exp.e(e) != output->e(e)) +// printf("%lld ", e); +// printf("\n"); + 
//result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_9_1) { + auto x = NDArrayFactory::create('c', {1, 4, 4, 2}, {1, 116, 2, 116, 3, 116, 4, 116, + 5, 117, 6, 117, 7, 117, 8, 117, + 9, 118, 10, 118, 11, 118, 12, 118, + 13, 119, 14, 119, 15, 119, 16, 119}); + //x.linspace(1); + +//Images shape is (1, 3, 3, 4) +//[1, 1, 1, 1] +//[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 4, 4, 8}, { + 1, 116, 2, 116, 5, 117, 6, 117, 2, 116, 3, 116, 6, 117, 7, 117, 3, 116, + 4, 116, 7, 117, 8, 117, 4, 116, 0, 0, 8, 117, 0, 0, 5, 117, 6, 117, + 9, 118, 10, 118, 6, 117, 7, 117, 10, 118, 11, 118, 7, 117, 8, 117, 11, 118, +12, 118, 8, 117, 0, 0, 12, 118, 0, 0, 9, 118, 10, 118, 13, 119, 14, 119, +10, 118, 11, 118, 14, 119, 15, 119, 11, 118, 12, 118, 15, 119, 16, 119, 12, 118, + 0, 0, 16, 119, 0, 0, 13, 119, 14, 119, 0, 0, 0, 0, 14, 119, 15, 119, + 0, 0, 0, 0, 15, 119, 16, 119, 0, 0, 0, 0, 16, 119, 0, 0, 0, 0, + 0, 0 + + }); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); +// output->printBuffer("OutputSame"); +// exp.printBuffer("ExpectSame"); +// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) +// if (exp.e(e) != output->e(e)) +// printf("%lld ", e); +// printf("\n"); + //result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +// +// +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_10) { + auto x = NDArrayFactory::create('c', {1, 6, 6, 2}); + 
x.linspace(1); + +//Images shape is (1, 3, 3, 4) +//[1, 1, 1, 1] +//[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 4, 4, 18}, { + 1., 2., 3., 4., 5., 6., 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., + 3., 4., 5., 6., 7., 8., 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., + 5., 6., 7., 8., 9., 10., 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., + 7., 8., 9., 10., 11., 12., 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., + 13., 14., 15., 16., 17., 18., 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., + 15., 16., 17., 18., 19., 20., 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., + 17., 18., 19., 20., 21., 22., 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., + 19., 20., 21., 22., 23., 24., 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., + 25., 26., 27., 28., 29., 30., 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., + 27., 28., 29., 30., 31., 32., 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., + 29., 30., 31., 32., 33., 34., 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., + 31., 32., 33., 34., 35., 36., 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., + 37., 38., 39., 40., 41., 42., 49., 50., 51., 52., 53., 54., 61., 62., 63., 64., 65., 66., + 39., 40., 41., 42., 43., 44., 51., 52., 53., 54., 55., 56., 63., 64., 65., 66., 67., 68., + 41., 42., 43., 44., 45., 46., 53., 54., 55., 56., 57., 58., 65., 66., 67., 68., 69., 70., + 43., 44., 45., 46., 47., 48., 55., 56., 57., 58., 59., 60., 67., 68., 69., 70., 71., 72.}); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + //x.printIndexedBuffer("Images"); + //x.printBuffer("Images linear"); + auto result = op.evaluate({&x}, {}, {3,3, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); +// 
output->printBuffer("OutputValid"); +// exp.printBuffer("ExpectValid"); +// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) +// if (exp.e(e) != output->e(e)) +// printf("%lld ", e); +// printf("\n"); + //result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_010) { + auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); + x.linspace(1); + +//Images shape is (1, 3, 3, 4) +//[1, 1, 1, 1] +//[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 3, 3, 4}, { + 1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 5, 6, 9, 10, 6, 7, 10, 11, 7, 8, 11, 12, + 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16}); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + //x.printIndexedBuffer("Images"); + //x.printBuffer("Images linear"); + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); +// output->printBuffer("OutputValid"); +// exp.printBuffer("ExpectValid"); +// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) +// if (exp.e(e) != output->e(e)) +// printf("%lld ", e); +// printf("\n"); + //result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_010_1) { + auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); + x.linspace(1); + +//Images shape is (1, 3, 3, 4) +//[1, 1, 1, 1] +//[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 4, 4, 4}, { + 1, 2, 5, 6, 2, 3, 6, 7, 3, 4, 7, 8, 4, 0, 8, 0, 5, 6, 9, 10, 6, 7, 10, 11, + 7, 8, 11, 12, 8, 0, 12, 0, 9, 10, 13, 14, 10, 11, 14, 15, 11, 12, 15, 16, 12, 0, 16, 0, + 13, 14, 0, 0, 14, 15, 0, 0, 15, 16, 0, 0, 16, 0, 0, 0}); +// ---------------------------------------------------------------- + 
sd::ops::extract_image_patches op; + //x.printIndexedBuffer("Images"); + //x.printBuffer("Images linear"); + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); +// output->printBuffer("OutputSame"); +// exp.printBuffer("ExpectSame"); +// exp.printIndexedBuffer("Expect Same Formatted"); +// output->printIndexedBuffer("Output Same Formatted"); +// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) +// if (exp.e(e) != output->e(e)) +// printf("%lld ", e); +// printf("\n"); + //result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_011) { + auto x = NDArrayFactory::create('c', {1, 4, 4, 1}); + x.linspace(1); + +//Images shape is (1, 3, 3, 4) +//[1, 1, 1, 1] +//[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 2, 2, 4}, { + 1, 3, 9, 11, 2, 4, 10, 12, 5, 7, 13, 15, 6, 8, 14, 16, + }); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + //x.printIndexedBuffer("Images"); + //x.printBuffer("Images linear"); + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 2,2, 0}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="VALID" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); +// output->printBuffer("OutputValid"); +// exp.printBuffer("ExpectValid"); +// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) +// if (exp.e(e) != output->e(e)) +// printf("%lld ", e); +// printf("\n"); + //result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_11) { + auto x = NDArrayFactory::create('c', {1, 8, 8, 
2}); + x.linspace(1); + +//Images shape is (1, 3, 3, 4) +//[1, 1, 1, 1] +//[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 4, 4, 8}, { + 1, 2, 3, 4, 17, 18, 19, 20, 5, 6, 7, 8, 21, 22, 23, 24, 9, 10, + 11, 12, 25, 26, 27, 28, 13, 14, 15, 16, 29, 30, 31, 32, 33, 34, 35, 36, + 49, 50, 51, 52, 37, 38, 39, 40, 53, 54, 55, 56, 41, 42, 43, 44, 57, 58, + 59, 60, 45, 46, 47, 48, 61, 62, 63, 64, 65, 66, 67, 68, 81, 82, 83, 84, + 69, 70, 71, 72, 85, 86, 87, 88, 73, 74, 75, 76, 89, 90, 91, 92, 77, 78, + 79, 80, 93, 94, 95, 96, 97, 98, 99, 100, 113, 114, 115, 116, 101, 102, 103, 104, + 117, 118, 119, 120, 105, 106, 107, 108, 121, 122, 123, 124, 109, 110, 111, 112, 125, 126, + 127, 128}); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2,2, 2,2, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); +// output->printBuffer("Output"); +// exp.printBuffer("Expect"); +// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) +// if (exp.e(e) != output->e(e)) +// printf("%lld ", e); +// printf("\n"); + //result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_12) { + auto x = NDArrayFactory::create('c', {1, 8, 8, 2}); + x.linspace(1); + +//Images shape is (1, 3, 3, 4) +//[1, 1, 1, 1] +//[1, 3, 2, 1] + auto exp = NDArrayFactory::create('c', {1, 8, 8, 8}, { + 0, 0, 0, 0, 0, 0, 19, 20, 0, 0, 0, 0, 17, 18, 21, 22, 0, 0, + 0, 0, 19, 20, 23, 24, 0, 0, 0, 0, 21, 22, 25, 26, 0, 0, 0, 0, + 23, 24, 27, 28, 0, 0, 0, 0, 25, 26, 29, 30, 0, 0, 0, 0, 27, 28, + 31, 32, 0, 0, 0, 0, 29, 30, 0, 0, 0, 0, 3, 4, 0, 0, 35, 36, + 1, 2, 5, 6, 33, 34, 37, 38, 3, 4, 7, 8, 35, 36, 39, 40, 5, 6, + 9, 
10, 37, 38, 41, 42, 7, 8, 11, 12, 39, 40, 43, 44, 9, 10, 13, 14, + 41, 42, 45, 46, 11, 12, 15, 16, 43, 44, 47, 48, 13, 14, 0, 0, 45, 46, + 0, 0, 0, 0, 19, 20, 0, 0, 51, 52, 17, 18, 21, 22, 49, 50, 53, 54, + 19, 20, 23, 24, 51, 52, 55, 56, 21, 22, 25, 26, 53, 54, 57, 58, 23, 24, + 27, 28, 55, 56, 59, 60, 25, 26, 29, 30, 57, 58, 61, 62, 27, 28, 31, 32, + 59, 60, 63, 64, 29, 30, 0, 0, 61, 62, 0, 0, 0, 0, 35, 36, 0, 0, + 67, 68, 33, 34, 37, 38, 65, 66, 69, 70, 35, 36, 39, 40, 67, 68, 71, 72, + 37, 38, 41, 42, 69, 70, 73, 74, 39, 40, 43, 44, 71, 72, 75, 76, 41, 42, + 45, 46, 73, 74, 77, 78, 43, 44, 47, 48, 75, 76, 79, 80, 45, 46, 0, 0, + 77, 78, 0, 0, 0, 0, 51, 52, 0, 0, 83, 84, 49, 50, 53, 54, 81, 82, + 85, 86, 51, 52, 55, 56, 83, 84, 87, 88, 53, 54, 57, 58, 85, 86, 89, 90, + 55, 56, 59, 60, 87, 88, 91, 92, 57, 58, 61, 62, 89, 90, 93, 94, 59, 60, + 63, 64, 91, 92, 95, 96, 61, 62, 0, 0, 93, 94, 0, 0, 0, 0, 67, 68, + 0, 0, 99, 100, 65, 66, 69, 70, 97, 98, 101, 102, 67, 68, 71, 72, 99, 100, + 103, 104, 69, 70, 73, 74, 101, 102, 105, 106, 71, 72, 75, 76, 103, 104, 107, 108, + 73, 74, 77, 78, 105, 106, 109, 110, 75, 76, 79, 80, 107, 108, 111, 112, 77, 78, + 0, 0, 109, 110, 0, 0, 0, 0, 83, 84, 0, 0, 115, 116, 81, 82, 85, 86, + 113, 114, 117, 118, 83, 84, 87, 88, 115, 116, 119, 120, 85, 86, 89, 90, 117, 118, + 121, 122, 87, 88, 91, 92, 119, 120, 123, 124, 89, 90, 93, 94, 121, 122, 125, 126, + 91, 92, 95, 96, 123, 124, 127, 128, 93, 94, 0, 0, 125, 126, 0, 0, 0, 0, + 99, 100, 0, 0, 0, 0, 97, 98, 101, 102, 0, 0, 0, 0, 99, 100, 103, 104, + 0, 0, 0, 0, 101, 102, 105, 106, 0, 0, 0, 0, 103, 104, 107, 108, 0, 0, + 0, 0, 105, 106, 109, 110, 0, 0, 0, 0, 107, 108, 111, 112, 0, 0, 0, 0, + 109, 110, 0, 0, 0, 0, 0, 0}); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 2,2, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,2,2,1], padding="SAME" + 
ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + //output->printShapeInfo("Output shape"); +// output->printIndexedBuffer("Output"); +// exp.printBuffer("Expect"); +// for (Nd4jLong e = 0; e < exp.lengthOf(); e++) +// if (exp.e(e) != output->e(e)) +// printf("%lld ", e); +// printf("\n"); + //result.at(1)->printBuffer("OUtput2"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestExtractImagePatches_SGO_13) { + auto x = NDArrayFactory::create('c', {1, 3, 3, 2}); + x.linspace(1); + + auto exp = NDArrayFactory::create('c', {1, 3, 3, 8}, { + 1., 2., 3., 4., 7., 8., 9., 10., 3., 4., 5., 6., 9., 10., 11., 12., 5., 6., + 0., 0., 11., 12., 0., 0., 7., 8., 9., 10., 13., 14., 15., 16., 9., 10., 11., 12., + 15., 16., 17., 18., 11., 12., 0., 0., 17., 18., 0., 0., 13., 14., 15., 16., 0., 0., + 0., 0., 15., 16., 17., 18., 0., 0., 0., 0., 17., 18., 0., 0., 0., 0., 0., 0. 
}); +// ---------------------------------------------------------------- + sd::ops::extract_image_patches op; + + auto result = op.evaluate({&x}, {}, {2,2, 1,1, 1,1, 1}); // equiv TF ksizes=[1,2,2,1], strides=[1,1,1,1], rates=[1,1,1,1], padding="SAME" + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_1) { + auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 +}); + +auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, + 21.41, 21.42, 22.11, 22.12 + }); +// ---------------------------------------------------------------- + sd::ops::roll op; + + auto result = op.evaluate({&x}, {}, {6}); + ASSERT_EQ(result.status(), Status::OK()); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_2) { + auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 +}); + +auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 
11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42 + }); +// ---------------------------------------------------------------- + sd::ops::roll op; + + auto result = op.evaluate({&x}, {}, {-8}); + ASSERT_EQ(result.status(), Status::OK()); + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_3) { + auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 +}); + +auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, + 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42 + }); +// ---------------------------------------------------------------- + sd::ops::roll op; + + auto result = op.evaluate({&x}, {}, {-40}); + ASSERT_EQ(result.status(), Status::OK()); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_4) { + auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 +}); + +auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, + 21.41, 21.42, 22.11, 22.12 + }); +// 
---------------------------------------------------------------- + sd::ops::roll op; + + auto result = op.evaluate({&x}, {}, {38}); + ASSERT_EQ(result.status(), Status::OK()); + //result.at(0)->printIndexedBuffer("Output 4"); + //exp.printIndexedBuffer("Expect 4"); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_4_inplace) { + auto x = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, + 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, 21.41, 21.42, 22.11, 22.12, 22.21, 22.22, 22.31, 22.32, 22.41, 22.42 +}); + +auto exp = NDArrayFactory::create('c', {2, 2, 4, 2}, { + 22.21, 22.22, 22.31, 22.32, 22.41, 22.42, 11.11, 11.12, 11.21, 11.22, 11.31, 11.32, 11.41, 11.42, + 12.11, 12.12, 12.21, 12.22, 12.31, 12.32, 12.41, 12.42, 21.11, 21.12, 21.21, 21.22, 21.31, 21.32, + 21.41, 21.42, 22.11, 22.12 + }); +// ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.execute({&x}, {y}, {}, {38}, {}, {}, true); + ASSERT_EQ(result, Status::OK()); + //x.printIndexedBuffer("Output 4 inplace"); + //exp.printIndexedBuffer("Expect 4 inplace"); + + ASSERT_TRUE(exp.equalsTo(&x)); + +// +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_5) { + auto x = NDArrayFactory::create('c', {3, 4}, { + 0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11. +}); + +auto exp = NDArrayFactory::create('c', {3, 4}, { + 2., 3., 0., 1., 6., 7., 4., 5., 10., 11., 8., 9. 
+// 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3 +}); +// ---------------------------------------------------------------- + sd::ops::roll op; + + auto result = op.evaluate({&x}, {}, {2, 1}); + ASSERT_EQ(result.status(), Status::OK()); + + //result.at(0)->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_6) { + auto x = NDArrayFactory::create('c', {2, 3, 2}, { + 0., 1., 2., 3., 4, 5., 6., 7., 8., 9., 10., 11. +}); + +auto exp = NDArrayFactory::create('c', {2, 3, 2}, { + 1., 0., 3., 2., 5., 4., 7., 6., 9., 8., 11., 10. +}); +// ---------------------------------------------------------------- + sd::ops::roll op; + + auto result = op.evaluate({&x}, {}, {1, 2}); + ASSERT_EQ(result.status(), Status::OK()); + + //result.at(0)->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_7) { + auto x = NDArrayFactory::create('c', {2, 3, 2}, { + 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11. +}); + +auto exp = NDArrayFactory::create('c', {2, 3, 2}, { + 11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2. +}); +// ---------------------------------------------------------------- + sd::ops::roll op; + + auto result = op.evaluate({&x}, {}, {1, 2, 1, 0}); + ASSERT_EQ(result.status(), Status::OK()); + + //result.at(0)->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(result.at(0))); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_8) { + auto x = NDArrayFactory::create('c', {2, 3, 2}, { + 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11. 
+}); + +auto exp = NDArrayFactory::create('c', {2, 3, 2}, { + 11., 10., 7., 6., 9., 8., 5., 4., 1., 0., 3., 2. +}); +// ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.execute({&x}, {y}, {}, {1, 2, 1, 0}, {}, {}, true); + ASSERT_EQ(result, Status::OK()); + + //x.printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(&x)); + +// +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_9) { + auto x = NDArrayFactory::create('c', {2, 3, 3}, { + 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17. +}); + +auto exp = NDArrayFactory::create('c', {2, 3, 3}, { + 6., 7., 8., 0., 1., 2., 3., 4., 5., 15., 16., 17., 9., 10., 11., 12., 13., 14. +}); +// ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.execute({&x}, {y}, {}, {1, 1}, {}, {}, true); + ASSERT_EQ(result, Status::OK()); + + ASSERT_TRUE(exp.equalsTo(&x)); + +// +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_10) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. 
+ }); +// ---------------------------------------------------------------- + sd::ops::roll op; + auto result = op.evaluate({&x}, {}, {3, 1}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_11) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + auto shift = NDArrayFactory::create({1,2}); + auto axis = NDArrayFactory::create({0, 1}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 17., 18., 19., 20., 21., 22., 23., 24., 13., 14., 15., 16., 5., 6., 7, 8, 9, 10, 11, 12, 1, 2, 3, 4 + }); +// ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_12) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. 
+ }); + auto shift = NDArrayFactory::create({1,1,1}); + auto axis = NDArrayFactory::create({0, 1, 2}); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 + }); +// ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_13) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. + }); + auto shift = NDArrayFactory::create(3); + auto axis = NDArrayFactory::create(2); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 2,3,4,1,6,7,8,5,10,11,12,9,14, 15, 16, 13, 18, 19, 20, 17, 22, 23, 24, 21 + }); +// ---------------------------------------------------------------- + sd::ops::roll op; + NDArray* y = nullptr; + auto result = op.evaluate({&x}, {}, {3,2}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); + + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_14) { + auto x = NDArrayFactory::create('c', {2, 3, 4}, { + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24. 
+ }); + auto shift = NDArrayFactory::create({1,1,1}); + auto axis = NDArrayFactory::create({0, 1, 2}); + + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { + 24, 21, 22, 23, 16, 13, 14, 15, 20, 17, 18, 19, 12, 9, 10, 11, 4, 1, 2, 3, 8, 5, 6, 7 + }); +// ---------------------------------------------------------------- + sd::ops::roll op; + + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); +// out->printIndexedBuffer("Output"); + //exp.printIndexedBuffer("Expect"); + + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TestRoll_15) { + auto x = NDArrayFactory::create({0.7788f, 0.8012f, 0.7244f, 0.2309f }); + auto shift = NDArrayFactory::create(2); + auto axis = NDArrayFactory::create(0); + + auto exp = NDArrayFactory::create({0.7244f, 0.2309f, 0.7788f, 0.8012f }); +// ---------------------------------------------------------------- + sd::ops::roll op; + + auto result = op.evaluate({&x, &shift, &axis}); + ASSERT_EQ(result.status(), Status::OK()); + auto out = result.at(0); +// out->printIndexedBuffer("Output 15"); +// exp.printIndexedBuffer("Expect 15"); + + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test1) { + + const int dim0=5, dim1=5, dim2=4; + + auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., + 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., + 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., + 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 
92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create(50.); + + sd::ops::percentile op; + + auto result = op.evaluate({&input}, {50.}, {}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test2) { + + const int dim0=5, dim1=5, dim2=4; + + auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., + 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., + 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., + 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create('c', {1,1,1}, {11.}); + + sd::ops::percentile op; + //q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test3) { + + const int dim0=5, dim1=5, dim2=4; + + auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., + 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., + 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., + 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 
95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create('c', {1,1,1}, {10.}); + + sd::ops::percentile op; + //q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 0, 1}, {}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test4) { + + const int dim0=5, dim1=5, dim2=4; + + auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., + 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., + 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., + 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + auto expected = NDArrayFactory::create('c', {1,1,1}, {11.}); + + sd::ops::percentile op; + //q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 1, 1}, {}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test5) { + + const int dim0=5, dim1=5, dim2=4; + + auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., + 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., + 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., + 63., 61., 
64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + + auto expected = NDArrayFactory::create('c', {1,1,4}, {12., 7., 11., 10.}); + + sd::ops::percentile op; + //q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 0, 1}, {0,1}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test6) { + + const int dim0=5, dim1=5, dim2=4; + + auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., + 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., + 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., + 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + + auto expected = NDArrayFactory::create('c', {1,1,4}, {16., 14., 15., 13.}); + + sd::ops::percentile op; + //q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 1, 1}, {0,1}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test7) { + + const int dim0=5, dim1=5, dim2=4; + + auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., + 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., + 41., 
44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., + 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + + auto expected = NDArrayFactory::create('c', {1,1,4}, {12., 7., 11., 10.}); + + sd::ops::percentile op; + //q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {0,1}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test8) { + + const int dim0=5, dim1=5, dim2=4; + + auto input = NDArrayFactory::create('c', {dim0, dim1, dim2}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., + 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., + 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., + 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + + auto expected = NDArrayFactory::create('c', {4}, {12., 7., 11., 10.}); + + sd::ops::percentile op; + //q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 0}, {0,1}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test9) { + + const int dim0=100; + + auto input = NDArrayFactory::create('c', {dim0}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., + 23., 21., 24., 26., 27., 25., 
28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., + 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., + 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + + auto expected = NDArrayFactory::create(11.); + + sd::ops::percentile op; + //q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 0}, {0}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test10) { + + const int dim0=100; + + auto input = NDArrayFactory::create('c', {dim0}, {6., 7., 83., 81., 84., 86., 87., 85., 88., 5., 8., 78., 79., 77., 80., 10., 16., 18., 19., 17., 20., 22., + 23., 21., 24., 26., 27., 25., 28., 30., 31., 29., 32., 38., 11., 9., 12., 14., 15., 13., 39., 37., 40., 42., 43., + 41., 44., 46., 47., 45., 48., 50., 51., 49., 52., 54., 55., 53., 56., 58., 59., 57., 60., 98., 99., 97.,100., 62., + 63., 61., 64., 66., 67., 65., 68., 70., 71., 69., 72., 74., 75., 73., 76., 2., 3., 1., 4., 94., 95., 93., 96., + 82., 90., 91., 89., 92., 34., 35., 33., 36.}); + + auto expected = NDArrayFactory::create('c', {1}, {11.}); + + sd::ops::percentile op; + //q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {0}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test11) { + + const int dim0=1; + + auto input = NDArrayFactory::create('c', {dim0}, {100.}); + + auto expected = NDArrayFactory::create('c', {1}, {100.}); + + sd::ops::percentile op; + //q, interpolation, 
keepDims + auto result = op.evaluate({&input}, {10, 2, 1}, {0}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, percentile_test12) { + + const int dim0=1; + + auto input = NDArrayFactory::create('c', {dim0}, {100.}); + + auto expected = NDArrayFactory::create(100.); + + sd::ops::percentile op; + //q, interpolation, keepDims + auto result = op.evaluate({&input}, {10, 2, 0}, {}); + auto output = result.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, transpose_test3) { + + auto input = NDArrayFactory::create('c', {5, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + auto exp = NDArrayFactory::create('c', {3, 5}, {1.f, 4.f, 7.f, 10.f, 13.f, 2.f, 5.f, 8.f, 11.f, 14.f, 3.f, 6.f, 9.f, 12.f, 15.f}); + + sd::ops::transpose op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, rationaltanh_test1) { + + auto input = NDArrayFactory::create('c', {8}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = NDArrayFactory::create({0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); + + sd::ops::rationaltanh op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Output rationaltanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, 
rationaltanh_test2) { + + auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.998222, 1.516093, 1.658054, 1.695077, 1.706884, 1.711427, 1.713446}); + + sd::ops::rationaltanh op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Output rationaltanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, rationaltanh_test3) { + + auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); + auto eps = NDArrayFactory::create('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray exp = NDArrayFactory::create('c', {2,2,2}, {1.143933, 1.605747, 0.795557, 0.261710, 0.095832, 0.041218, 0.020221, 0.010971}); + + sd::ops::rationaltanh_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + auto output = result.at(0); +// output->printBuffer("Output rationaltanh BP"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, rectifiedtanh_test1) { + + auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); + NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.761594, 0.964028, 0.995055, 0.999329, 0.999909, 0.999988, 0.999998}); + + sd::ops::rectifiedtanh op; + auto result = op.evaluate({&input}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Output rectifiedtanh"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, rectifiedtanh_test2) { + + auto input = NDArrayFactory::create('c', {2,2,2}, {0, 1, 2, 3, 4, 5, 6, 7}); + auto eps = 
NDArrayFactory::create('c', {2,2,2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray exp = NDArrayFactory::create('c', {2,2,2}, {0.000000, 0.839949, 0.211952, 0.039464, 0.006705, 0.001089, 0.000172, 0.000027}); + + sd::ops::rectifiedtanh_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + auto output = result.at(0); +// output->printBuffer("Output rectifiedtanh BP"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +TEST_F(DeclarableOpsTests7, RealDiv_1) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f,2.f}); + NDArray e = NDArrayFactory::create('c', {1, 2, 2}, {2.f, 1.f, 4.f, 2.f}); + + sd::ops::realdiv op; + auto result = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("OUtput RealDiv"); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, RealDiv_BP_1) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); + NDArray e0 = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 5.f}); + NDArray e1 = NDArrayFactory::create('c', {1, 2}, {-14.f, -5.f}); + NDArray eps = NDArrayFactory::create('c', {1, 2, 2}, {1.f, 2.f, 3.f, 4.f}); + + sd::ops::realdiv_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z0 = result.at(0); + auto z1 = result.at(1); +// z0->printShapeInfo("OUtput RealDiv BP0 shape"); +// z1->printShapeInfo("OUtput RealDiv BP1 shape"); +// z0->printIndexedBuffer("OUtput RealDiv BP0"); +// z1->printIndexedBuffer("OUtput RealDiv BP1"); +// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e0.equalsTo(z0)); + ASSERT_TRUE(e1.equalsTo(z1)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, ShapesOf_1) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); +// NDArray y = NDArrayFactory::create('c', {1, 2}, {1,2}); + NDArray e = NDArrayFactory::create({1, 2, 1}); + + sd::ops::shapes_of op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("OUtput RealDiv"); +// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, ShapesOf_2) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {1.f, 2.f}); + NDArray e0 = NDArrayFactory::create({1, 2, 1}); + NDArray e1 = NDArrayFactory::create({1, 2}); + + sd::ops::shapes_of op; + auto result = op.evaluate({&x, &y}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z0 = result.at(0); + auto z1 = result.at(1); +// z0->printIndexedBuffer("OUtput shapes2"); +// z1->printIndexedBuffer("OUtput shapes2"); +// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e0.equalsTo(z0)); + ASSERT_TRUE(e1.equalsTo(z1)); + + +} + +TEST_F(DeclarableOpsTests7, Size_1) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2.f, 4.f}); + NDArray y = NDArrayFactory::create('c', {5, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 7.f, 9.f, 10.f, 10.f, 11.f}); + NDArray e = NDArrayFactory::create(2); + + sd::ops::size op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("OUtput SIZE"); +/// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + +} + +TEST_F(DeclarableOpsTests7, Size_2) { + + NDArray x = NDArrayFactory::create('c', {1, 2, 1}, {2, 4}); + NDArray y = 
NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray e = NDArrayFactory::create(10); + + sd::ops::size op; + auto result = op.evaluate({&y}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("OUtput SIZE"); +/// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + +} + +TEST_F(DeclarableOpsTests7, Softplus_1) { + + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); + + sd::ops::softplus op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("OUtput Softplus"); +/// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + +} + +TEST_F(DeclarableOpsTests7, Softplus_BP_1) { + + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); +// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616, 2.126928, 3.0485873, 4.01815, 5.0067153, 7.0009117, 9.000123, 10.000046, 10.000046, 11.000016}); + NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); + sd::ops::softplus ffOP; + sd::ops::softplus_bp bpOp; + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); + + bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(gradOK); +// +// auto z = result.at(0); +// z->printIndexedBuffer("OUtput Softplus"); +///// ASSERT_TRUE(e.isSameShape(z)); +// ASSERT_TRUE(e.equalsTo(*z)); +// +// +} + +TEST_F(DeclarableOpsTests7, Softsign_1) { + + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray e = NDArrayFactory::create('c', {5, 2}, {0.5, 0.6666667, 0.75, 0.8, 0.8333333, 0.875, 0.9, 0.90909094, 0.90909094, 0.9166667}); + + sd::ops::softsign op; + auto 
result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("OUtput Softsign"); +/// ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(*z)); + + +} + +TEST_F(DeclarableOpsTests7, Softsign_BP_1) { + + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); +// NDArray e = NDArrayFactory::create('c', {5, 2}, {1.3132616f, 2.126928f, 3.0485873f, 4.01815f, 5.0067153f, 7.0009117f, 9.000123f, 10.000046f, 10.000046f, 11.000016f}); + NDArray eps = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,6,7,8, 9, 10}); + sd::ops::softsign ffOP; + sd::ops::softsign_bp bpOp; + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &eps}, {}, {}); + + bool gradOK = GradCheck::checkGrad(ffOP, bpOp, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(gradOK); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, fill_test2) { + + auto x = NDArrayFactory::create('c', {1,2}, {2, 2}); + auto v = NDArrayFactory::create(42.); + auto exp = NDArrayFactory::create('c', {2, 2},{42.f, 42.f, 42.f, 42.f}); + + sd::ops::fill op; + auto result = op.evaluate({&x, &v}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, fill_test3) { + + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto v = NDArrayFactory::create(42.); + auto exp = NDArrayFactory::create('c', {2, 2}, {42.f, 42.f, 42.f, 42.f}); + + sd::ops::fill op; + auto result = op.evaluate({&x, &v}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, ToggleBits_test1) { + + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto exp = NDArrayFactory::create('c', {2}, {-3, -3}); + + sd::ops::toggle_bits op; + auto result = op.evaluate({&x}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, ToggleBits_test2) { + + auto x = NDArrayFactory::create('c', {2}, {2, 2}); + auto y = NDArrayFactory::create('c', {2}, {1, 1}); + auto exp0 = NDArrayFactory::create('c', {2}, {-3, -3}); + auto exp1 = NDArrayFactory::create('c', {2}, {-2, -2}); + + sd::ops::toggle_bits op; + auto result = op.evaluate({&x, &y}); + auto output = result.at(0); + auto z = result.at(1); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp0.isSameShape(output)); + ASSERT_TRUE(exp0.equalsTo(output)); + ASSERT_TRUE(exp1.isSameShape(z)); + ASSERT_TRUE(exp1.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Truncatediv_test1) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray y = NDArrayFactory::create('c', {5, 2}, {2,2,2,2,2,2,2,2, 2, 2}); + NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); + + sd::ops::truncatediv op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, 
Truncatediv_test2) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray y = NDArrayFactory::create('c', {1, 2}, {2,2}); + NDArray exp = NDArrayFactory::create('c', {5, 2}, {0.5, 1., 1.5, 2., 2.5, 3.5, 4.5, 5., 5., 5.5}); + + sd::ops::truncatediv op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp.isSameShape(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TypesConversion_test1) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray expI = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray expL = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray expF = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); + NDArray expF16 = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); + + sd::ops::to_int32 op32; + sd::ops::to_int64 op64; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result64.status()); + auto out1 = result32.at(0); +// out1->printIndexedBuffer("OUT_I"); + auto out2 = result64.at(0); +// out2->printIndexedBuffer("OUT_L"); + +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(expI.equalsTo(out1)); + ASSERT_TRUE(expL.equalsTo(out2)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TypesConversion_test2) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray expF = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); + NDArray expH = NDArrayFactory::create('c', {5, 2}, 
{1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); + + sd::ops::to_float32 op32; + sd::ops::to_float16 op16; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result16 = op16.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result16.status()); + auto out1 = result32.at(0); +// out1->printIndexedBuffer("OUT_F"); + auto out2 = result16.at(0); +// out2->printIndexedBuffer("OUT_H"); + +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(expF.equalsTo(out1)); + ASSERT_TRUE(expH.equalsTo(out2)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TypesConversion_test3) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + + sd::ops::to_uint32 op32; + sd::ops::to_uint64 op64; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result64.status()); + auto out1 = result32.at(0); +// out1->printIndexedBuffer("OUT_U32"); + auto out2 = result64.at(0); +// out2->printIndexedBuffer("OUT_U64"); + +// output->printIndexedBuffer("Toggled"); + ASSERT_TRUE(exp32.equalsTo(out1)); + ASSERT_TRUE(exp64.equalsTo(out2)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, TypesConversion_test4) { + NDArray x = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + NDArray exp32 = NDArrayFactory::create('c', {5, 2}, {1.f,2.f,3.f,4.f,5.f,7.f,9.f,10.f, 10.f, 11.f}); + NDArray exp64 = NDArrayFactory::create('c', {5, 2}, {1,2,3,4,5,7,9,10, 10, 11}); + + sd::ops::to_float32 op32; + sd::ops::to_double op64; + auto result32 = op32.evaluate({&x}, {}, {}); + auto result64 = op64.evaluate({&x}, 
{}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result32.status()); + ASSERT_EQ(ND4J_STATUS_OK, result64.status()); + auto out1 = result32.at(0); + auto out2 = result64.at(0); + + ASSERT_TRUE(exp32.equalsTo(out1)); + ASSERT_TRUE(exp64.equalsTo(out2)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test1) { + + auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); + + auto exp = NDArrayFactory::create('c', {4, 7}, {2, 1, 1, 2, 3, 3, 2, 2, 1, 1, 2, 3, 3, 2, 5, 4, 4, 5, 6, 6, 5, 5, 4, 4, 5, 6, 6, 5}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test2) { + + auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 2, 2}); + + auto exp = NDArrayFactory::create('c', {4, 7}, {6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1, 6, 5, 4, 5, 6, 5, 4, 3, 2, 1, 2, 3, 2, 1}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test3) { + + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {1,2}, {2, 2}); + + auto exp = NDArrayFactory::create('c', {7}, {2, 1, 1, 2, 3, 3, 2}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + 
ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test4) { + + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2}, {2, 3}); + + auto exp = NDArrayFactory::create('c', {8}, {2, 1, 1, 2, 3, 3, 2, 1}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test5) { + + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2}, {2, 2}); + + auto exp = NDArrayFactory::create('c', {7}, {3, 2, 1, 2, 3, 2, 1}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test6) { + + auto input = NDArrayFactory::create(1.); + auto paddings = NDArrayFactory::create('c', {1,2,1,1}, {1, 1}); + + auto exp = NDArrayFactory::create('c', {3}, {1,1,1}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test7) { + + auto input = NDArrayFactory::create(1.); + auto paddings = NDArrayFactory::create('c', {2}, {1, 1}); + + auto exp = NDArrayFactory::create('c', {3}, {1,1,1}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto 
output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test8) { + + auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 3, 3}); + + auto exp = NDArrayFactory::create('c', {3,9}, {3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + ASSERT_EQ(result.status(), Status::OK()); + + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test9) { + + auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {2, 2, 3, 3}); + + auto exp = NDArrayFactory::create('c', {6, 9}, {6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1, 3, 2, 1, 1, 2, 3, 3, 2, 1, 6, 5, 4, 4, 5, 6, 6, 5, 4, 6, 5, 4, 4, 5, 6, 6, 5, 4, 3, 2, 1, 1, 2, 3, 3, 2, 1}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test10) { + + auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + + auto exp = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test11) { + + auto input = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + + auto exp = NDArrayFactory::create('c', {1,3}, {1., 2., 3.}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test12) { + + auto input = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto paddings = NDArrayFactory::create('c', {2,1}, {0, 0}); + + auto exp = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test13) { + + auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {0, 0, 0, 0}); + + auto exp = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test14) { + + auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1LL, 0LL, 0LL, 1LL}); + + auto exp = NDArrayFactory::create('c', {3, 4}, {4, 5, 6, 5, 1, 2, 3, 2, 4, 5, 6, 
5}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test15) { + + auto input = NDArrayFactory::create('c', {2, 3}, {1., 2., 3., 4., 5., 6.}); + auto paddings = NDArrayFactory::create('c', {2, 2}, {1, 1, 0, 0}); + + auto exp = NDArrayFactory::create('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6}); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {1}); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, mirrorPad_test16) { + + auto input = NDArrayFactory::create('c', {4,3,2}); + auto paddings = NDArrayFactory::create('c', {3,2}, {3,3,2,2,1,1}); + + auto exp = NDArrayFactory::create('c', {10,7,4}, {24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,22., 21., 22., 21.,24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., + 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., + 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., + 24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,22., 21., 22., 21.,24., 23., 24., 23.,22., 21., 22., 21.,20., 19., 20., 19.,18., 17., 
18., 17.,16., 15., 16., 15.,14., 13., 14., 13.,16., 15., 16., 15.,18., 17., 18., 17.,16., 15., 16., 15.,14., 13., 14., 13., + 12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7.,10., 9., 10., 9.,12., 11., 12., 11.,10., 9., 10., 9., 8., 7., 8., 7., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1., 4., 3., 4., 3., 6., 5., 6., 5., 4., 3., 4., 3., 2., 1., 2., 1.}); + input.linspace(1.); + + sd::ops::mirror_pad op; + auto result = op.evaluate({&input, &paddings}, {}, {0}); + ASSERT_EQ(result.status(), Status::OK()); + auto output = result.at(0); + //output->printBuffer("VVV"); + //exp.printBuffer("EXP"); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_1) { + + auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(120.f); + //************************************// + + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + //z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_2) { + + auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); + //************************************// + + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {1}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_1) { + + auto input = 
NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(1307674368000.f); + //************************************// + + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + //z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_2) { + + auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); + //************************************// + + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {1}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_01) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_02) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); +// 
output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_3) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_4) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_5) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_6) { + + auto x = NDArrayFactory::create('c', 
{2,3,4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_7) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,1}, {300.f}); + x.linspace(1); +// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_01) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_02) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', {1,1,2}, {10395.f, 46080.f}); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_3) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_4) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {112.f, 1080.f, 3960.f}); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_5) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_6) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); + + 
sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_7) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); + x.linspace(1); +// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +TYPED_TEST(TypedDeclarableOpsTests7, Test_Pnorm_Once_Again) { + auto input = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); + auto exp = NDArrayFactory::create('c', {1, 1, 5, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f}); + + sd::ops::pnormpool2d op; + auto result = op.evaluate({&input}, {}, {1,1, 1,1, 0,0, 1,1,1, 3, 0}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(exp, *result.at(0)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); +// 
output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 5.f, 9.f}); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_5) { + + auto x = 
NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + x.linspace(1); +// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); +// output->printShapeInfo("Output shape"); + 
ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_2) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {16.f, 20.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + 
auto exp = NDArrayFactory::create(24.f); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); +// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + 
ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_2) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); + + 
sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); + x.linspace(1); +// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_2) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); + + 
sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); + x.linspace(1); +// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto 
result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); + + 
sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_1) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// 
z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_2) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_3) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f}); + //************************************// + + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {0}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Sum_BP_4) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f}); + 
//************************************// + + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_1) { + + auto input = NDArrayFactory::create('c', {3, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + auto eps = NDArrayFactory::create(1307674368000.f); + //************************************// +// auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + auto exp = NDArrayFactory::create('c', {3, 5}, {1710012166826558903812096.f, 855006083413279451906048.f, 570004067618451974258688.f, + 427503041706639725953024.f, 342002454982589992140800.f, 285002033809225987129344.f, + 244287457550765131825152.f, 213751520853319862976512.f, 190001355872817324752896.f, + 171001227491294996070400.f, 155455648254341989531648.f, 142501016904612993564672.f, + 131539399526781282156544.f, 122143728775382565912576.f, 114000815325130245799936.f}); + + sd::ops::reduce_prod_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_2) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create(0.5f); + //************************************// +// auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + auto exp = NDArrayFactory::create('c', {3, 4}); + + sd::ops::reduce_prod_bp op; + sd::ops::reduce_prod op_exp; + auto res = op_exp.evaluate({&input}); + auto result = op.evaluate({&input, &eps}, {}, {}); + exp.assign(res.at(0)->e(0)); + exp /= input; + exp *= eps.e(0); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + //z->printIndexedBuffer("Result is "); + //exp.printIndexedBuffer("Expected"); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_3) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + //************************************// + auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); + + sd::ops::reduce_prod_bp op; + //sd::ops::reduce_prod op_exp; + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// exp.printIndexedBuffer("Expected"); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_03) { + int ax = 0; + auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + //************************************// + auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + 
sd::ops::reduce_prod_bp op; + //sd::ops::reduce_prod op_exp; + auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// exp.printIndexedBuffer("Expected"); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_4) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + //************************************// + auto exp = NDArrayFactory::create('c', {3, 4}, {45.f, 120.f, 231.f, 384.f, 9.f, 40.f, 99.f, 192.f, 5.f, 24.f, 63.f, 128.f}); + + sd::ops::reduce_prod_bp op; + sd::ops::reduce_prod op_exp; +// auto res = op_exp.execute({&input}, {}, {}); + auto result = op.evaluate({&input, &eps}, {0.f}, {0}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// exp.printIndexedBuffer("Expected"); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +// +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Prod_BP_5) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f}); + auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + //************************************// + auto exp = NDArrayFactory::create('c', {3, 4}, {24.f, 12.f, 8.f, 6.f, 672.f, 560.f, 480.f, 420.f, 3960.f, 3564.f, 3240.f, 2970.f}); + + sd::ops::reduce_prod_bp op; + sd::ops::reduce_prod op_exp; +// auto res = op_exp.execute({&input}, {}, {}); + auto result = op.evaluate({&input, &eps}, {0.f}, {1}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result 
is "); +// exp.printIndexedBuffer("Expected"); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +// +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(0, eps.e(0)); + exp.p(1, eps.e(1)); + exp.p(2, eps.e(2)); + exp.p(3, eps.e(3)); + x.linspace(1); +// x.printIndexedBuffer("Input is"); +// exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(0, eps.e(0)); + exp.p(1, eps.e(1)); + exp.p(2, eps.e(2)); + exp.p(3, eps.e(3)); + x.linspace(1); +// x.printIndexedBuffer("Input is"); +// exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_02) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', 
{2, 3, 4}); + exp.p(0, eps.e(0)); + exp.p(1, eps.e(1)); + exp.p(2, eps.e(2)); + exp.p(3, eps.e(3)); + auto axes = NDArrayFactory::create({0,1}); + x.linspace(1); +// x.printIndexedBuffer("Input is"); +// exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_3) { + + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create('c', {3, 4}); + x.linspace(1); + x.p(2,2, -1.f); + exp.p(2,2, 0.5f); + //x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_4) { + + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {3, 4}); + x.linspace(1); + x.p(2,2, -1.f); + exp.p(2,2, 0.5f); +// x.printIndexedBuffer("Input is"); +// exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_5) { + + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0,0, -1.f); + x.p(1,1, -2.f); + x.p(2,2, -3.f); + x.p(3,3, -4.f); + exp.p(0,0, 1.f); + exp.p(1,1, 2.f); + exp.p(2,2, 3.f); + exp.p(3,3, 4.f); +// exp(2,2) = 0.5f; +// x.printIndexedBuffer("Input is"); +// exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Min_BP_6) { + + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0,0, -1.f); + x.p(1,1, -2.f); + x.p(2,2, -3.f); + x.p(3,3, -4.f); + exp.p(0,0, 1.f); + exp.p(1,1, 2.f); + exp.p(2,2, 3.f); + exp.p(3,3, 4.f); +// exp(2,2) = 0.5f; +// x.printIndexedBuffer("Input is"); +// exp.printIndexedBuffer("Expected "); + sd::ops::reduce_min_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + auto exp = 
NDArrayFactory::create('c', {2, 3, 4}); + exp.p(20, eps.e(0)); + exp.p(21, eps.e(1)); + exp.p(22, eps.e(2)); + exp.p(23, eps.e(3)); + x.linspace(1); + // x.printIndexedBuffer("Input is"); + // exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1}); + auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(20, eps.e(0)); + exp.p(21, eps.e(1)); + exp.p(22, eps.e(2)); + exp.p(23, eps.e(3)); + x.linspace(1); +// x.printIndexedBuffer("Input is"); +// exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_02) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1, 4}, {21.f, 22.f, 23.f, 24.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + exp.p(20, eps.e(0)); + exp.p(21, eps.e(1)); + exp.p(22, eps.e(2)); + exp.p(23, eps.e(3)); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); +// x.printIndexedBuffer("Input is"); +// exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); +// output->printIndexedBuffer("Result 
is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_3) { + + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0,0, 21.f); + x.p(1,1, 22.f); + x.p(2,2, 23.f); + x.p(3,3, 24.f); + exp.p(0,0, 1.f); + exp.p(1,1, 2.f); + exp.p(2,2, 3.f); + exp.p(3,3, 4.f); +// x.printIndexedBuffer("Input is"); +// exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Max_BP_4) { + + auto x = NDArrayFactory::create('c', {4, 4}); + auto eps = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {4, 4}); + x.linspace(1); + x.p(0,0, 21.f); + x.p(1,1, 22.f); + x.p(2,2, 23.f); + x.p(3,3, 24.f); + exp.p(0,0, 1.f); + exp.p(1,1, 2.f); + exp.p(2,2, 3.f); + exp.p(3,3, 4.f); + +// x.printIndexedBuffer("Input is"); +// exp.printIndexedBuffer("Expected "); + sd::ops::reduce_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_1) { + + auto x = NDArrayFactory::create('c', 
{2, 3, 4}); + auto eps = NDArrayFactory::create(5.f); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.p(12, -2.f); + x.p(20, -3.f); + exp.assign(5.f); + exp.p(12, -exp.e(12)); + exp.p(20, -exp.e(20)); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0,1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + // exp.printIndexedBuffer("Expect is"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_02) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create({1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); + auto axes = NDArrayFactory::create({0,1}); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + 
ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm1_BP_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1, 4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f,1.f, 2.f, 3.f, 4.f}); + sd::ops::reduce_norm1_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); + auto output = result.at(0); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); + + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0,1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); + + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_02) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1, 4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); + + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); + + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Norm2_BP_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); + + sd::ops::reduce_norm2_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(x.isSameShape(output)); + ASSERT_TRUE(x.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_BP_1) { + + auto x = 
NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 2.f, 8.f, 18.f, 32.f, + 10.f, 24.f, 42.f, 64.f, + 18.f, 40.f, 66.f, 96.f, + 26.f, 56.f, 90.f, 128.f, + 34.f, 72.f, 114.f, 160.f, + 42.f, 88.f, 138.f, 192.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0,1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_SquaredNorm_BP_01) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, { 2.f, 8.f, 18.f, 32.f, + 10.f, 24.f, 42.f, 64.f, + 18.f, 40.f, 66.f, 96.f, + 26.f, 56.f, 90.f, 128.f, + 34.f, 72.f, 114.f, 160.f, + 42.f, 88.f, 138.f, 192.f}); + auto axes = NDArrayFactory::create({0, 1}); + x.linspace(1); + + sd::ops::reduce_sqnorm_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {false}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(20, 1.f); + exp.p(21, 2.f); + exp.p(22, 3.f); + exp.p(23, 4.f); + + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, 
result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(20, 1.f); + exp.p(21, 2.f); + exp.p(22, 3.f); + exp.p(23, 4.f); + + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_02) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + auto axes = NDArrayFactory::create({0,1}); + x.linspace(1); + exp.p(20, 1.f); + exp.p(21, 2.f); + exp.p(22, 3.f); + exp.p(23, 4.f); + + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps, &axes}, {}, {}, {true}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + + exp.p(15, 1.f); + exp.p(19, 2.f); + exp.p(23, 3.f); + + sd::ops::reduce_norm_max_bp op; + auto 
result = op.evaluate({&x, &eps}, {}, {0,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 3, 1}, {1.f, 2.f, 3.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(15, 1.f); + exp.p(19, 2.f); + exp.p(23, 3.f); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {0,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(23, 1.f); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(23, 1.f); + + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0, 1, 2}); + auto output = result.at(0); 
+// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_NormMax_BP_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto eps = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + exp.p(23, 1.f); + sd::ops::reduce_norm_max_bp op; + auto result = op.evaluate({&x, &eps}, {1.f}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3, 4}); + NDArray* z; // = NDArrayFactory::create('c', {4}); + auto eps = NDArrayFactory::create(1.f); +// auto exp = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + y.linspace(2); + + + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {}); + auto output = result.at(0); + auto outputX = result.at(1); + //tput->printIndexedBuffer("Result is"); + +// ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(x.equalsTo(outputX)); + ASSERT_TRUE(y.equalsTo(output)); + + +// delete z; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3, 4}); +// auto z; // = NDArrayFactory::create('c', {4}); + auto eps = NDArrayFactory::create('c', {2, 4}); + auto expX = NDArrayFactory::create('c', {2, 3, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 
4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, + 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f + }); + auto expY = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f + }); + x.assign(1.f); + eps.linspace(1); + y.assign(2.f); + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps}, {}, {1}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + ASSERT_EQ(result.size(), 2); + auto outputX = result.at(0); + auto outputY = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(expX.equalsTo(outputX)); + ASSERT_TRUE(expY.equalsTo(outputY)); + + +// delete z; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_02) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3, 4}); +// auto z; // = NDArrayFactory::create('c', {4}); + auto eps = NDArrayFactory::create('c', {2, 4}); + auto expX = NDArrayFactory::create('c', {2, 3, 4}, {2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, 2.f, 4.f, 6.f, 8.f, + 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f, 10.f, 12.f, 14.f, 16.f + }); + auto expY = NDArrayFactory::create('c', {2, 3, 4}, {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f + }); + auto axis = NDArrayFactory::create('c', {1}, {1}); + x.assign(1.f); + eps.linspace(1); + y.assign(2.f); + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x, &y, &eps, &axis}, {}, {}, {false}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + ASSERT_EQ(result.size(), 2); + auto outputX = result.at(0); + auto outputY = result.at(1); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(expX.equalsTo(outputX)); + ASSERT_TRUE(expY.equalsTo(outputY)); + + +// delete z; +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, Test_Reduce_Dot_BP_3) { + + auto x = NDArrayFactory::create('c', {3, 4}); + auto y = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3}); + auto expX = NDArrayFactory::create('c', {3, 4}, {2.f, 2.f, 2.f, 2.f, 4.f, 4.f, 4.f, 4.f, 6.f, 6.f, 6.f, 6.f}); + auto expY = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, 10.f, 12.f, 14.f, 16.f, 27.f, 30.f, 33.f, 36.f}); + x.linspace(1); + eps.linspace(1); + y.assign(2.f); + + sd::ops::reduce_dot_bp op; + auto result = op.evaluate({&x,&y, &eps}, {}, {1}); + auto outputX = result.at(0); + auto outputY = result.at(1); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(expX.equalsTo(outputX)); + ASSERT_TRUE(expY.equalsTo(outputY)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, cumsum_bp_1) { + + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create('c', {3, 4}, {12.f, 11.f, 10.f, 9.f, 8.f, 7.f, + 6.f, 5.f, 4.f, 3.f, 2.f, 1.f}); + x.linspace(1); + eps.assign(1.f); + + sd::ops::cumsum_bp op; + auto result = op.evaluate({&x, &eps}, {}, {0,0}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, cumsum_bp_2) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create('c', {3, 4}, { 11.f, 10.f, 9.f, 8.f, 7.f, 6.f, + 5.f, 4.f, 3.f, 2.f, 1.f, 0.f}); + x.linspace(1); + eps.assign(1.f); + + + sd::ops::cumsum_bp op; + auto result = op.evaluate({&x, &eps}, {}, {1,0}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, 
result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests7, cumsum_bp_3) { + auto x = NDArrayFactory::create('c', {3, 4}); + auto eps = NDArrayFactory::create('c', {3, 4}); + auto exp = NDArrayFactory::create('c', {3, 4}); + + x.linspace(1); + exp.linspace(0); + eps.assign(1.f); + + sd::ops::cumsum_bp op; + auto result = op.evaluate({&x, &eps}, {}, {1,1}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.equalsTo(output)); + + + +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests8.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests8.cpp new file mode 100644 index 000000000..30d2408ff --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests8.cpp @@ -0,0 +1,3525 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 10.06.2018 +// + + +#include "testlayers.h" +#include +#include +#include +// #include + +using namespace sd; + + +class DeclarableOpsTests8 : public testing::Test { +public: + + DeclarableOpsTests8() { + printf("\n"); + fflush(stdout); + } +}; + +template +class TypedDeclarableOpsTests8 : public testing::Test { +public: + + TypedDeclarableOpsTests8() { + printf("\n"); + fflush(stdout); + } +}; + +typedef ::testing::Types TestingTypes; +TYPED_TEST_CASE(TypedDeclarableOpsTests8, TestingTypes); + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVariance_test1) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); + auto exp = NDArrayFactory::create('c', {4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); + + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVariance_test2) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {602.2222f, 727.13885f, 993.5555f, 755.8889f}); + + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {1.}, {0,1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVariance_test3) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); + auto exp = NDArrayFactory::create('c', {3}, {900.9375f, 969.8594f, 424.1875f}); + + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {0,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVariance_test4) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {900.9375f, 969.8594f, 424.1875f}); + + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {1.}, {0,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVariance_test5) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.f}); + auto exp = NDArrayFactory::create(788.6927f); + + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVariance_test6) { + + auto 
x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create(788.6927f); + + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVariance_test7) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create('c', {1,1,1}, {788.6927f}); + + sd::ops::reduce_variance op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVariance_test8) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create('c', {1,1,1}, {788.6927f}); + auto axes = NDArrayFactory::create({0, 1, 2}); + sd::ops::reduce_variance op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDev_test1) { + + auto x = NDArrayFactory::create('c', {2,3,4}, 
{27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create('c', {4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); + + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDev_test2) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {24.54022f, 26.96551f, 31.52072f, 27.49343f}); + + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {1.}, {0,1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDev_test3) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create('c', {3}, {30.01562f, 31.14257f, 20.59581f}); + + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {0,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDev_test4) { + + auto x = NDArrayFactory::create('c', {2,3,4}, 
{27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {30.01562f, 31.14257f, 20.59581f}); + + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {1.}, {0,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDev_test5) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create(28.08367f); + + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDev_test6) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create(28.08367f); + + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDev_test7) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create('c', {1,1,1}, {28.08367f}); + + 
sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {1.f}, {0,1,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDev_test8) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create('c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); + + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x}, {0.f,1.f}, {0,1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + // output->printBuffer("Reduced STDDEV"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDev_test08) { + + auto x = NDArrayFactory::create('c', {2,3,4}, {27.f,34.f,5.f,4.f,54.f,6.f,65.f,8.f,37.f,45.f,8.f,67.f,96.f,10.f,65.f,41.f,33.f,85.f,92.f,24.f,25.f,55.f,49.f,76.}); + auto exp = NDArrayFactory::create('c', {4}, {26.88246f, 29.53924f, 34.52921f, 30.11755f}); + auto axes = NDArrayFactory::create({0,1}); + sd::ops::reduce_stdev op; + auto result = op.evaluate({&x, &axes}, {}, {}, {false, true}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + // output->printBuffer("Reduced STDDEV08"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVarianceBP_test1) { + + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {1,1}, {0.5f}); + auto gradO2 = NDArrayFactory::create(0.5f); + auto exp12 = NDArrayFactory::create('c', 
{3,4}, {-0.5f, -0.4090909f, -0.3181818f, -0.22727273f, -0.13636364f, -0.045454547f, 0.045454547f, 0.13636364f, 0.22727273f, 0.3181818f, 0.4090909f, 0.5f}); + auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.45833334f, -0.375f, -0.29166666f, -0.20833333f, -0.125f, -0.041666668f, 0.041666668f, 0.125f, 0.20833333f, 0.29166666f, 0.375f, 0.45833334f}); + + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0,1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {0,0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVarianceBP_test2) { + + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1.,2.,3.,4.}); + auto exp12 = NDArrayFactory::create('c', {3,4}, {-2.666667f, -5.333333f, -8.000000f, -10.666667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 2.666667f, 5.333333f, 8.000000f, 10.666667f}); + auto exp34 = NDArrayFactory::create('c', {3,4}, {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); + + x.linspace(1); + + sd::ops::reduce_variance_bp 
op; + + auto result = op.evaluate({&x, &gradO2}, {0,0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {0,1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVarianceBP_test02) { + + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1.f,2.f,3.f,4.f}); + auto exp12 = NDArrayFactory::create('c', {3,4}, {-2.666667f, -5.333333f, -8.000000f, -10.666667f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 2.666667f, 5.333333f, 8.000000f, 10.666667f}); + auto exp34 = NDArrayFactory::create('c', {3,4}, {-4.000000f, -8.000000f, -12.000000f, -16.000000f, 0.000000f, 0.000000f, 0.000000f, 0.000000f, 4.000000f, 8.000000f, 12.000000f, 16.000000f}); + auto axes = NDArrayFactory::create({(int)0,}); + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, false}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, false}); + ASSERT_EQ(Status::OK(), 
result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2, &axes}, {}, {}, {false, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1, &axes}, {}, {}, {true, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceVarianceBP_test3) { + + auto x = NDArrayFactory::create('c', {3, 4}); + auto gradO1 = NDArrayFactory::create('c', {3, 1}, {1.f, 2.f, 3.f}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto exp12 = NDArrayFactory::create('c', {3, 4}, + {-0.750000f, -0.250000f, 0.250000f, 0.750000f, -1.500000f, -0.500000f, + 0.500000f, 1.500000f, -2.250000f, -0.750000f, 0.750000f, 2.250000f}); + auto exp34 = NDArrayFactory::create('c', {3, 4}, + {-1.000000f, -0.333333f, 0.333333f, 1.000000f, -2.000000f, -0.666667f, + 0.666667f, 2.000000f, -3.000000f, -1.000000f, 1.000000f, 3.000000f}); + + x.linspace(1); + + sd::ops::reduce_variance_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0, 0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1, 0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {0, 1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + + result = 
op.evaluate({&x, &gradO1}, {1, 1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDevBP_test1) { + + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {1,1}, {0.5f}); + auto gradO2 = NDArrayFactory::create(0.5f); + auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.069337524f, -0.056730703f, -0.04412388f, -0.031517055f, -0.018910235f, -0.0063034114f, 0.0063034114f, 0.018910235f, 0.031517055f, 0.04412388f, 0.056730703f, 0.069337524f}); + auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.06638563f, -0.05431551f, -0.0422454f, -0.030175284f, -0.01810517f, -0.006035057f, 0.006035057f, 0.01810517f, 0.030175284f, 0.0422454f, 0.05431551f, 0.06638563f}); + + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0,1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {0,0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDevBP_test2) { + + auto x = 
NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1.f,2.f,3.f,4.f}); + auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.4082483f, -0.8164966f, -1.2247449f, -1.6329932f, 0.0, 0.0, 0.0, 0.0, 0.4082483f, 0.8164966f, 1.2247449f, 1.6329932f}); + auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.5f, -1.0f, -1.5f, -2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5f, 1.0f, 1.5f, 2.0f}); + + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0,0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {0,1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDevBP_test02) { + + int ax = 0; + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {1,4}, {1.f,2.f,3.f,4.f}); + auto gradO2 = NDArrayFactory::create('c', {4}, {1.f,2.f,3.f,4.f}); + auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.4082483f, -0.8164966f, -1.2247449f, -1.6329932f, 0.0, 0.0, 0.0, 0.0, 0.4082483f, 0.8164966f, 1.2247449f, 1.6329932f}); + auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.5f, -1.0f, -1.5f, -2.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.5f, 1.0f, 1.5f, 
2.0f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, false}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {true, true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDevBP_test3) { + + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {3,1}, {1.f,2.f,3.f}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1.f,2.f,3.f}); + auto exp12 = NDArrayFactory::create('c', {3,4}, {-0.3354102f, -0.1118034f, 0.1118034f, 0.3354102f, -0.6708204f, -0.2236068f, 0.2236068f, 0.6708204f, -1.0062306f, -0.3354102f, 0.3354102f, 1.0062306f}); + auto exp34 = NDArrayFactory::create('c', {3,4}, {-0.38729835f, -0.12909944f, 0.12909944f, 0.38729835f, -0.7745967f, -0.2581989f, 0.2581989f, 0.7745967f, -1.161895f, -0.38729835f, 0.38729835f, 1.161895f}); + + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0,0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + 
ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp12.isSameShape(output)); + ASSERT_TRUE(exp12.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {0,1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp34.isSameShape(output)); + ASSERT_TRUE(exp34.equalsTo(output)); + + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_1) { + + auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(120.f); + //************************************// + + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + //z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_2) { + + auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); + //************************************// + + sd::ops::reduce_sum op; + auto result = op.evaluate({&input}, {}, {1}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_03) { + + auto input = NDArrayFactory::create('c', {3, 5}, 
{1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({15.f, 40.f, 65.f}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + //************************************// + + sd::ops::reduce_sum op; + auto result = op.evaluate({&input, &axis}, {}, {}, {false}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + // z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_1) { + + auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create(1307674368000.f); + //************************************// + + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + //z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_2) { + + auto input = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto exp = NDArrayFactory::create({120.f, 30240.f, 360360.f}); + //************************************// + + sd::ops::reduce_prod op; + auto result = op.evaluate({&input}, {}, {1}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_01) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = 
result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_02) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_3) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_4) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_5) { 
+ + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_6) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); + + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_7) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,1}, {300.f}); + x.linspace(1); +// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_sum op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_01) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', {2}, {10395.f, 46080.f}); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_02) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', {1,1,2}, {10395.f, 46080.f}); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_3) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', {3}, {112.f, 1080.f, 3960.f}); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_4) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {112.f, 1080.f, 3960.f}); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_04) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', 
{1,3,1}, {112.f, 1080.f, 3960.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_5) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_6) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create(479001600.f); + x.linspace(1); + + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_7) { + + auto x = NDArrayFactory::create('c', {2,3,2}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {479001600.f}); + x.linspace(1); +// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_prod op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Min_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Min_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {1.f, 2.f, 3.f, 4.f}); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Min_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 5.f, 9.f}); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Min_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 5.f, 
9.f}); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Min_04) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 5.f, 9.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Min_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Min_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(1.f); + x.linspace(1); + + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Min_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {1.f}); + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_min op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Max_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + // output->printShapeInfo("Output shape"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Max_2) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Max_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', 
{3}, {16.f, 20.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Max_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {16.f, 20.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Max_04) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {16.f, 20.f, 24.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Max_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + 
ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Max_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); + + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Max_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); +// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_max op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); + // output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_2) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {66.f, 72.f, 78.f, 84.f}); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = 
op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {68.f, 100.f, 132.f}); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_04) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {68.f, 100.f, 132.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(300.f); + x.linspace(1); + + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm1_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {300.f}); + x.linspace(1); +// x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::reduce_norm1 op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); 
+// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_2) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {31.7175f, 33.823071f, 35.97221f, 38.15757f}); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_04) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {29.597298f, 39.344631f, 49.759422f}); + auto axes = NDArrayFactory::create({0,2}); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(70.f); + x.linspace(1); + + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Norm2_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {70.f}); + x.linspace(1); +// x.printIndexedBuffer("Input 
with shape (2, 3, 4) is"); + sd::ops::reduce_norm2 op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {21.f, 22.f, 23.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {16.f, 20.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + 
ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {0,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_04) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {16.f, 20.f, 24.f}); + auto axes = NDArrayFactory::create({0,2}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(24.f); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = 
NDArrayFactory::create(24.f); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_NormMax_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {24.f}); + x.linspace(1); + + sd::ops::reduce_norm_max op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {1006.f, 1144.f, 1294.f, 1456.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0,1}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {3}, {876.f, 1548.f, 2476.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {0,2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_04) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 3, 1}, {876.f, 1548.f, 2476.f}); + auto axes = NDArrayFactory::create({0, 2}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_5) { + + 
auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create(4900.f); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {}, {0, 1, 2}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_SquaredNorm_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1, 1}, {4900.f}); + x.linspace(1); + + sd::ops::reduce_sqnorm op; + auto result = op.evaluate({&x}, {1.f}, {}); + auto output = result.at(0); +// output->printIndexedBuffer("Result is"); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_1) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + + sd::ops::reduce_sum_bp op; + auto result = 
op.evaluate({&input, &eps}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_2) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 1}, {0.5f}); + auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.5f, 0.5f, + 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_3) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f}); + //************************************// + + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {}, {0}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_4) { + + auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto 
exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f}); + //************************************// + + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps}, {1.f}, {0}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Sum_BP_04) { + + int ax = 0; + auto input = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12.}); + auto eps = NDArrayFactory::create('c', {1, 4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {3, 4}, {1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f, + 1.f, 2.f, 3.f, 4.f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + //************************************// + + sd::ops::reduce_sum_bp op; + auto result = op.evaluate({&input, &eps, &axis}, {}, {}, {true}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Reduce_Prod_BP_1) { + + auto input = NDArrayFactory::create('c', {3, 5}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}); + auto eps = NDArrayFactory::create(1307674368000.f); + //************************************// +// auto exp = NDArrayFactory::create('c', {3, 4}, {0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,0.5f}); + //************************************// + auto exp = NDArrayFactory::create('c', {3, 5}, {1710012166826558903812096.f, 855006083413279451906048.f, 570004067618451974258688.f, + 427503041706639725953024.f, 342002454982589992140800.f, 285002033809225987129344.f, + 
244287457550765131825152.f, 213751520853319862976512.f, 190001355872817324752896.f, + 171001227491294996070400.f, 155455648254341989531648.f, 142501016904612993564672.f, + 131539399526781282156544.f, 122143728775382565912576.f, 114000815325130245799936.f}); + + sd::ops::reduce_prod_bp op; + auto result = op.evaluate({&input, &eps}, {}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); +// z->printIndexedBuffer("Result is "); +// z->printShapeInfo(); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMean_test1) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {4}, {11.f, 12.f, 13.f, 14.f}); + x.linspace(1); + + + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {0,1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMean_test2) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,4}, {11.f, 12.f, 13.f, 14.f}); + x.linspace(1); + + + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {1.}, {0,1}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMean_test3) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {3}, {8.5f, 12.5f, 16.5f}); + x.linspace(1); + + + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {0,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMean_test4) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {8.5f, 12.5f, 16.5f}); + x.linspace(1); + + + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {1.f}, {0,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMean_test5) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create(12.5f); + x.linspace(1); + + + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMean_test6) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create(12.5f); + x.linspace(1); + + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMean_test7) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,1}, {12.5f}); + x.linspace(1); + + sd::ops::reduce_mean op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMean_test8) { + + auto x = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {1,1,1}, {12.5f}); + auto axes = NDArrayFactory::create({0, 1, 2}); + x.linspace(1); + + sd::ops::reduce_mean op; + auto result = op.evaluate({&x, &axes}, {}, {}, {true}); + auto output = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMeanBP_test1) { + + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create(0.5f); + auto gradO2 = NDArrayFactory::create('c', {1,1}, {0.5f}); + auto exp = NDArrayFactory::create('c', {3,4}, {1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24, 1./24}); + + x.linspace(1); + + sd::ops::reduce_mean_bp op; + + auto result = op.evaluate({&x, &gradO1}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + + // output->printShapeInfo("o"); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMeanBP_test2) { + + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {3,4}, {1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 
4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f}); + + x.linspace(1); + + sd::ops::reduce_mean_bp op; + + auto result = op.evaluate({&x, &gradO1}, {0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMeanBP_test02) { + + int ax = 0; + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto gradO2 = NDArrayFactory::create('c', {1,4}, {1.f, 2.f, 3.f, 4.f}); + auto exp = NDArrayFactory::create('c', {3,4}, {1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f, 1.f/3.f, 2.f/3.f, 1.f, 4.f/3.f}); + auto axis = NDArrayFactory::create('c', {1}, {ax}); + x.linspace(1); + + sd::ops::reduce_mean_bp op; + + auto result = op.evaluate({&x, &gradO1, &axis}, {}, {}, {false}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMeanBP_test3) { + + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto gradO2 = NDArrayFactory::create('c', {3,1}, {1.f, 2.f, 3.f}); + auto exp = NDArrayFactory::create('c', {3,4}, {0.25f, 0.25f, 0.25f, 0.25f, 0.5f, 0.5f, 0.5f, 0.5f, 0.75f, 0.75f, 0.75f, 0.75f}); + + 
x.linspace(1); + + sd::ops::reduce_mean_bp op; + + auto result = op.evaluate({&x, &gradO1}, {0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDevBP_test4) { + + auto x = NDArrayFactory::create('c', {3}, {2.f, 3.f, 4.f}); + auto gradO = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {3}, {-0.25f, 0.f, 0.25f}); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO}, {0,1}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test1) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {2,3}, {2.78507, 1.34254, 4.12761, 2.88507, 2.78507, 2.88507}); + + logits.linspace(0.1, 0.1); + + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test2) { + + auto labels = NDArrayFactory::create('c', 
{2,3,4},{0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {3,4}, {0.26328, 1.46328, 1.72656, 0. , 0.26328, 0. , 1.46328, 0.26328, 1.72656, 0. , 1.72656, 1.46328}); + + logits.linspace(0.1, 0.1); + + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test3) { + + auto labels = NDArrayFactory::create('c', {2,3,4},{0,1,1,0,0,0,1,0,1,0,1,1,1,0,1,0,1,0,0,1,1,0,1,0}); + auto logits = NDArrayFactory::create('c', {2,3,4}); + auto expected = NDArrayFactory::create('c', {2,4}, {0.75125, 1.55125, 3.45375, 0.75125, 3.45375, 0. , 2.3025 , 1.15125}); + + logits.linspace(0.1, 0.1); + + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {1}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test4) { + + auto labels = NDArrayFactory::create('c', {2,3},{0,1,1,0,0,1}); + auto logits = NDArrayFactory::create('c', {2,3}); + auto expected = NDArrayFactory::create('c', {2}, {2.10389, 1.00194}); + + logits.linspace(0.1, 0.1); + + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + 
+ +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test5) { + + auto labels = NDArrayFactory::create('c', {2,3},{0,1,1,0,0,1}); + auto logits = NDArrayFactory::create('c', {2,3}); + auto expected = NDArrayFactory::create('c', {3}, {0., 0.85436, 1.40871}); + + logits.linspace(0.1, 0.1); + + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test6) { + + auto labels = NDArrayFactory::create('c', {2,1}, {0,1}); + auto logits = NDArrayFactory::create('c', {2,1}); + auto expected = NDArrayFactory::create('c', {1}, {0.6444}); + + logits.linspace(0.1, 0.1); + + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test7) { + + auto labels = NDArrayFactory::create('c', {2,1}, {0,1}); + auto logits = NDArrayFactory::create('c', {2,1}); + auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); + + logits.linspace(0.1, 0.1); + + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {1}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test8) { + + auto labels = NDArrayFactory::create('c', {2}, {0,1}); + auto logits = NDArrayFactory::create('c', {2}); + auto expected = NDArrayFactory::create(0.6444); + + logits.linspace(0.1, 0.1); + + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto *output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test9) { + + auto labels = NDArrayFactory::create('c', {1}, {0.}); + auto logits = NDArrayFactory::create('c', {1}, {0.2}); + auto expected = NDArrayFactory::create(0.); + + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, softmax_cross_entropy_loss_with_logits_test10) { + + auto labels = NDArrayFactory::create('c', {1,2}, {0,1}); + auto logits = NDArrayFactory::create('c', {1,2}); + auto expected = NDArrayFactory::create('c', {2}, {0., 0.}); + + logits.linspace(0.1, 0.1); + + sd::ops::softmax_cross_entropy_loss_with_logits op; + auto results = op.evaluate({&logits, &labels}, {}, {0}); + + ASSERT_EQ(Status::OK(), results.status()); + + auto output = results.at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, 
reduceMeanBP_test4) { + + auto x = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }); + auto gradO1 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); + auto gradO2 = NDArrayFactory::create('c', {1, 4}, {1., 2., 3., 4.}); + auto exp = NDArrayFactory::create('c', {3,4}, {0.333333, 0.666667, 1.000000, 1.333333, 0.333333, 0.666667, 1.000000, 1.333333, 0.333333, 0.666667, 1.000000, 1.333333}); + + sd::ops::reduce_mean_bp op; + + auto result = op.evaluate({&x, &gradO1}, {0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceMeanBP_test5) { + + auto x = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. 
}); + auto gradO1 = NDArrayFactory::create('c', {3}, {1., 2., 3.}); + auto gradO2 = NDArrayFactory::create('c', {3, 1}, {1., 2., 3.}); + auto exp = NDArrayFactory::create('c', {3,4}, {0.2500,0.2500,0.2500,0.2500, 0.5000,0.5000,0.5000,0.5000, 0.7500,0.7500,0.7500,0.7500}); + + sd::ops::reduce_mean_bp op; + + auto result = op.evaluate({&x, &gradO1}, {0}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {1}, {1}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, reduceStDevBP_test5) { + + auto x = NDArrayFactory::create('c', {3, 4}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12. }); + auto gradO1 = NDArrayFactory::create('c', {4}, {1., 2., 3., 4.}); + auto gradO2 = NDArrayFactory::create('c', {1, 4}, {1., 2., 3., 4.}); + auto exp = NDArrayFactory::create('c', {3,4}, {-0.408248, -0.816497, -1.224745, -1.632993, 0.000000, 0.000000, 0.000000, 0.000000, 0.408248, 0.816497, 1.224745, 1.632993}); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO1}, {0}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + auto output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + + result = op.evaluate({&x, &gradO2}, {1}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, zeros_as_test1) { + + auto x = NDArrayFactory::create(10.f); + auto y = NDArrayFactory::create(100.f); + auto exp = NDArrayFactory::create(0.f); + + sd::ops::zeros_as op; 
+ + Nd4jStatus status = op.execute({&x}, {&y}, {}, {}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, zeros_as_test2) { + + auto x = NDArrayFactory::create(10.f); + //auto y = NDArrayFactory::create(100.f); + auto exp = NDArrayFactory::create(0.f); + + sd::ops::zeros_as op; + + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + auto y = result.at(0); + + ASSERT_TRUE(y->isSameShape(exp)); + ASSERT_TRUE(y->equalsTo(exp)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, ones_as_test1) { + + auto x = NDArrayFactory::create(10.); + auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); + + sd::ops::ones_as op; + + Nd4jStatus status = op.execute({&x}, {&y}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(y.isSameShape(exp)); + ASSERT_TRUE(y.equalsTo(exp)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, ones_as_test2) { + + auto x = NDArrayFactory::create(10.); + //auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); + + sd::ops::ones_as op; + + auto results = op.evaluate({&x}); + ASSERT_EQ(Status::OK(), results.status()); + auto y = results.at(0); + ASSERT_TRUE(y->isSameShape(exp)); + ASSERT_TRUE(y->equalsTo(exp)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, ones_as_test3) { + + auto x = NDArrayFactory::create(10.); + //auto y = NDArrayFactory::create(100.); + auto exp = NDArrayFactory::create(1.); + + sd::ops::ones_as op; + + auto results = op.evaluate({&x}, {}, {}, {}, {sd::DataType::INT32}); + ASSERT_EQ(Status::OK(), results.status()); + auto y = results.at(0); + + 
ASSERT_TRUE(y->isSameShape(exp)); + ASSERT_TRUE(y->equalsTo(exp)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, NormalizeMoments_SGO_1) { + + auto data = NDArrayFactory::create('c', {10, 10}); + data.linspace(1); + + auto means = data.reduceAlongDimension(reduce::Sum, {0}); + auto deviance = NDArrayFactory::create('c', {10}, {825., 825. , 825., 825., 825., 825., 825., 825., 825., 825. }); // data.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); // = NDArrayFactory::create('c', {10, 10}); + + auto counts = NDArrayFactory::create(10.0); + +// auto expMeans = NDArrayFactory::create('c', {10, 10}); + +// auto expDeviance = NDArrayFactory::create('c', {10, 10}); + auto squared = NDArrayFactory::create('c', {10, 10}); + data.applyTransform(transform::Square, squared); + auto ssSquared = squared.reduceAlongDimension(reduce::Sum, {0}); +// ssSquared->printBuffer("Sum squared"); +// squared.printBuffer("Squared"); + sd::ops::normalize_moments op; + auto results = op.evaluate({&counts, &means, &ssSquared}, {0.0}, {0}); + means /= counts; +// sd::ops::normalize_moments op; +// auto results = op.evaluate({&counts, means, deviance}, {0.0}, {}); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_EQ(results.size(), 2); + + auto outputMeans = results.at(0); + auto outputDeviance = results.at(1); + +// outputMeans->printIndexedBuffer("Means"); +// outputDeviance->printIndexedBuffer("Variance"); +// deviance.printIndexedBuffer("Expected"); +// means->printIndexedBuffer("Expected means"); + ASSERT_TRUE(means.isSameShape(outputMeans)); + ASSERT_TRUE(means.equalsTo(outputMeans)); + ASSERT_TRUE(deviance.isSameShape(outputDeviance)); + ASSERT_TRUE(deviance.equalsTo(outputDeviance)); + //delete deviance; +// ASSERT_TRUE(expMeans.isSameShape(outputMeans)); +// ASSERT_TRUE(expMeans.equalsTo(outputMeans)); +// ASSERT_TRUE(expMeans.isSameShape(outputDeviance)); +// 
ASSERT_TRUE(expDeviance.equalsTo(outputDeviance)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Moments_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = NDArrayFactory::create('c', {4}, {11.f, 12.f, 13.f, 14.f}); + auto expVariance = NDArrayFactory::create('c', {4}, {46.666668f, 46.666668f, 46.66666f, 46.666668f}); + x.linspace(1); + + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 1}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + +// outputMeans->printIndexedBuffer("Means"); +// outputVariance->printIndexedBuffer("Variance"); +// outputMeans->printShapeInfo("Result shape"); + + +// ASSERT_TRUE(exp.isSameShape(output)); +// ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Moments_2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = NDArrayFactory::create('c', {1,1,4}, {11.f, 12.f, 13.f, 14.f}); + auto expVariance = NDArrayFactory::create('c', {1,1,4}, {46.666668f, 46.666668f, 46.66666f, 46.666668f}); + x.linspace(1); + + sd::ops::moments op; + auto result = op.evaluate({&x}, {1.}, {0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + +// outputMeans->printIndexedBuffer("Means"); +// outputVariance->printIndexedBuffer("Variance"); +// outputMeans->printShapeInfo("Result shape"); + +// ASSERT_TRUE(exp.isSameShape(output)); +// ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + 
ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Moments_3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = NDArrayFactory::create('c', {3}, {8.5f, 12.5f, 16.5f}); + auto expVariance = NDArrayFactory::create('c', {3}, {37.25f, 37.25f, 37.25f}); + x.linspace(1); + + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0, 2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + +// outputMeans->printIndexedBuffer("Means"); +// outputVariance->printIndexedBuffer("Variance"); +// outputMeans->printShapeInfo("Result shape"); + +// ASSERT_TRUE(exp.isSameShape(output)); +// ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Moments_4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto expMeans = NDArrayFactory::create('c', {1,3,1}, {8.5f, 12.5f, 16.5f}); + auto expVariance = NDArrayFactory::create('c', {1,3,1}, {37.25f, 37.25f, 37.25f}); + x.linspace(1); + + sd::ops::moments op; + auto result = op.evaluate({&x}, {1.}, {0, 2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + +// outputMeans->printIndexedBuffer("Means"); +// outputVariance->printIndexedBuffer("Variance"); +// outputMeans->printShapeInfo("Result shape"); + +// ASSERT_TRUE(exp.isSameShape(output)); +// ASSERT_TRUE(exp.equalsTo(output)); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + 
ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Moments_6) { + auto expMeans = NDArrayFactory::create(12.5f); + auto expVariance = NDArrayFactory::create(47.916668f); + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + x.linspace(1); + + sd::ops::moments op; + auto result = op.evaluate({&x}, {}, {0,1,2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + +// outputMeans->printIndexedBuffer("Means"); +// outputVariance->printIndexedBuffer("Variance"); +// outputMeans->printShapeInfo("Result shape"); + + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests8, Test_Moments_7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + + auto expMeans = NDArrayFactory::create('c', {1,1,1}, {12.5f}); + auto expVariance = NDArrayFactory::create('c', {1,1,1}, {47.916668f}); + + x.linspace(1); + // x.printIndexedBuffer("Input with shape (2, 3, 4) is"); + sd::ops::moments op; + auto result = op.evaluate({&x}, {1.}, {0,1,2}); + ASSERT_EQ(Status::OK(), result.status()); + + auto outputMeans = result.at(0); + auto outputVariance = result.at(1); + +// outputMeans->printIndexedBuffer("Means"); +// outputVariance->printIndexedBuffer("Variance"); +// outputMeans->printShapeInfo("Result shape"); + ASSERT_TRUE(expMeans.isSameShape(outputMeans)); + ASSERT_TRUE(expMeans.equalsTo(outputMeans)); + ASSERT_TRUE(expVariance.isSameShape(outputVariance)); + ASSERT_TRUE(expVariance.equalsTo(outputVariance)); + + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_01) { + + auto x = NDArrayFactory::create('c', {1, 1, 2, 5}, { 1.f, 2.f, 3.f, 4.f, 5.f, + 6.f, 7.f, 8.f, 9.f, 10.f} + ); + + auto exp = NDArrayFactory::create('c', {1, 1, 2, 5}, {0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, 0.4898979f, 0.46056613f, 0.43971977f, 0.5240003f, 0.6375767f}// 0.72760683, 0.4850712, 0.5848977, 0.67488194, +// 0.7581754, 0.58321184, 0.86747235, 0.4048204} + ); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + //ASSERT_TRUE(exp.isSameShape(out)); + //out->printBuffer("LRN out"); + //exp.printBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_02) { + + auto x = NDArrayFactory::create('c', {1, 1, 1, 6}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + + auto exp = NDArrayFactory::create('c', {1, 1, 1, 6}, { + 0.2581989f, 0.3592106f, 0.40089184f, 0.4193139f, 0.5360563f, 0.67936623f} + ); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + //ASSERT_TRUE(exp.isSameShape(out)); + //out->printIndexedBuffer("LRN out"); +// exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_03) { + + auto x = NDArrayFactory::create('c', {1, 1, 1, 10}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f}); + auto exp = NDArrayFactory::create('c', {1, 1, 1, 10}, {0.10425719f, 0.16843036f, 0.2095291f, 0.23652494f, 0.25449327f, 0.3053919f, 0.35675305f, 0.4098524f, 0.46662825f, 0.52999896f}); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, 
{1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + // out->printIndexedBuffer("LRN out"); +// exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_1) { + + auto x = NDArrayFactory::create('c', {2, 2, 2, 2}, { 5.5f, 0.f, 0.3f, 5.5f, + 8.6f, 0.f, 0.f, 0.4f, + 1.5f, 1.f, 1.3f, 1.5f, + 2.6f, 2.f, 3.f, 1.4f} + ); + + auto exp = NDArrayFactory::create('c', {2, 2, 2, 2}, { + 0.98386997f, 0.f, 0.05358852f, 0.9824562f, + 0.99330735f, 0.f, 0.f, 0.37139067f, + 0.72760683f, 0.4850712f, 0.5848977f, 0.67488194f, + 0.7581754f, 0.58321184f, 0.86747235f, 0.4048204f} + ); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); +// out->printIndexedBuffer("LRN out"); +// exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_2) { + + auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); + x.linspace(1); + + auto exp = NDArrayFactory::create('c', {3, 3, 5, 5}, { + 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 
0.5049936f, 0.5890103f, + + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + + + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + + + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + + 0.57448506f, 0.4987563f, 
0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f } + ); +// + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); +// ASSERT_TRUE(exp.isSameShape(out)); +// out->printIndexedBuffer("LRN out"); +// exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_3) { + + auto x = NDArrayFactory::create('c', {3, 3, 5, 5}); + x.linspace(1); + + auto exp = NDArrayFactory::create('c', {3, 3, 5, 5}, { + 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, + + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + + + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 
0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + + + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f } + ); +// + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); +// 
ASSERT_TRUE(exp.isSameShape(out)); +// out->printIndexedBuffer("LRN out"); +// exp.printIndexedBuffer("LRN exp"); + ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4) { + + // auto x = NDArrayFactory::create('c', {8, 32, 64, 64}); + auto x = NDArrayFactory::create('c', {2, 8, 16, 16}); + x.linspace(1); + + sd::ops::lrn op; + auto results = op.evaluate({&x}, {1.0, 1.0, 0.5}, {2}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); +// ASSERT_TRUE(exp.isSameShape(out)); +// out->printIndexedBuffer("LRN out"); +// exp.printIndexedBuffer("LRN exp"); +// ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_4_119) { + int iterations = 1000; + // auto x = NDArrayFactory::create('c', {8, 32, 64, 64}); + // auto z = NDArrayFactory::create('c', {8, 32, 64, 64}); + auto x = NDArrayFactory::create('c', {2, 8, 16, 16}); + auto z = NDArrayFactory::create('c', {2, 8, 16, 16}); + x.linspace(1); + + sd::ops::lrn op; + + op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}); + + auto timeStart = std::chrono::system_clock::now(); + + for (int e = 0; e < iterations; e++) + op.execute({&x}, {&z}, {1.0, 1.0, 0.5}, {2}); + + auto timeEnd = std::chrono::system_clock::now(); + auto spanTime = std::chrono::duration_cast ((timeEnd - timeStart) / iterations).count(); + auto ttlTime = std::chrono::duration_cast ((timeEnd - timeStart)).count(); + + +// ASSERT_TRUE(exp.isSameShape(out)); +// ASSERT_TRUE(exp.equalsTo(out)); +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_01) { + + auto x = NDArrayFactory::create( 'c', {1, 1, 1, 10}); + x.linspace(1); + auto eps = NDArrayFactory::create('c', {1,1,1,10}); + eps.linspace(1); +// +// auto exp = 
NDArrayFactory::create('c', {3,3,5,5}, { +// 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, 0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, 0.357765, 0.375316, 0.384881, 
0.384882, 0.374677, 0.357765, 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, 0.357766, 0.375275, 0.384886} +// ); +/// + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {5}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); +// ASSERT_TRUE(exp.isSameShape(out)); + //out->printBuffer("LRN BP out"); + //exp.printBuffer("LRN BP exp"); + //ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_02) { + + auto x = NDArrayFactory::create( 'c', {1, 1, 1, 10}); + x.linspace(1); + auto eps = NDArrayFactory::create('c', {1,1,1,10}); + eps.linspace(1); +// +// auto exp = NDArrayFactory::create('c', {3,3,5,5}, { +// 0.238337, 0.309664, 0.334077, 0.376534, 0.342926, 0.370734, 0.362017, 0.354182, 0.379140, 0.376275, 0.380027, 0.368347, 0.356401, 0.378316, 0.381315, 0.382465, 0.370592, 0.357055, 0.377670, 0.382950, 0.383445, 0.371718, 0.357332, 0.377217, 0.383677, 0.383933, 0.372391, 0.357475, 0.376891, 0.384062, 0.384212, 0.372837, 0.357557, 0.376646, 0.384290, 0.384385, 0.373153, 0.357610, 0.376457, 0.384436, 0.384500, 0.373389, 0.357645, 0.376306, 0.384536, 0.384581, 0.373572, 0.357670, 0.376184, 0.384606, 0.384639, 0.373718, 0.357688, 0.376082, 0.384658, 0.384683, 0.373837, 0.357702, 0.375996, 0.384698, 0.384717, 0.373935, 0.357712, 0.375923, 0.384728, 0.384743, 0.374019, 0.357721, 0.375860, 0.384752, 0.384764, 0.374090, 0.357727, 0.375804, 0.384771, 0.384781, 0.374152, 0.357733, 0.375756, 0.384787, 0.384795, 0.374205, 0.357737, 0.375713, 0.384800, 0.384807, 0.374253, 0.357741, 0.375674, 0.384811, 0.384817, 0.374295, 0.357744, 0.375640, 0.384820, 0.384825, 0.374333, 0.357747, 0.375609, 0.384828, 0.384832, 0.374366, 
0.357749, 0.375581, 0.384835, 0.384839, 0.374397, 0.357751, 0.375555, 0.384841, 0.384844, 0.374425, 0.357753, 0.375531, 0.384846, 0.384849, 0.374450, 0.357754, 0.375510, 0.384850, 0.384853, 0.374473, 0.357756, 0.375490, 0.384854, 0.384856, 0.374494, 0.357757, 0.375471, 0.384858, 0.384860, 0.374514, 0.357758, 0.375454, 0.384861, 0.384863, 0.374532, 0.357759, 0.375438, 0.384864, 0.384865, 0.374549, 0.357760, 0.375423, 0.384866, 0.384868, 0.374565, 0.357760, 0.375410, 0.384868, 0.384870, 0.374579, 0.357761, 0.375397, 0.384870, 0.384872, 0.374593, 0.357762, 0.375384, 0.384872, 0.384873, 0.374606, 0.357762, 0.375373, 0.384874, 0.384875, 0.374618, 0.357763, 0.375362, 0.384875, 0.384876, 0.374629, 0.357763, 0.375352, 0.384877, 0.384878, 0.374640, 0.357764, 0.375342, 0.384878, 0.384879, 0.374650, 0.357764, 0.375333, 0.384879, 0.384880, 0.374660, 0.357764, 0.375325, 0.384880, 0.384881, 0.374669, 0.357765, 0.375316, 0.384881, 0.384882, 0.374677, 0.357765, 0.375309, 0.384882, 0.384883, 0.374685, 0.357765, 0.375301, 0.384883, 0.384884, 0.374693, 0.357765, 0.375294, 0.384884, 0.384884, 0.374700, 0.357766, 0.375287, 0.384885, 0.384885, 0.374707, 0.357766, 0.375281, 0.384885, 0.384886, 0.374714, 0.357766, 0.375275, 0.384886} +// ); +/// + sd::ops::lrn opFF; + sd::ops::lrn_bp opBP; + + const OpArgsHolder argsHolderFF({&x}, {1., 1., 0.5}, {5}); + const OpArgsHolder argsHolderBP({&x, &eps}, {1., 1., 0.5}, {5}); + + bool gradOK = true; //GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + //auto results = op.execute({&x, &eps}, {1.0, 1.0, 0.5}, {5}, {}, false, sd::DataType::DOUBLE); + //auto out = results.at(0); + + //ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradOK); + //out->printBuffer("LRN BP out"); + //exp.printBuffer("LRN BP exp"); + //ASSERT_TRUE(exp.equalsTo(out)); + + // +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_1) { + + auto x = NDArrayFactory::create( 
'c', {3, 3, 5, 5}); + x.linspace(1); + auto eps = NDArrayFactory::create('c', {3,3,5,5}); + eps.linspace(1); +// +auto exp = NDArrayFactory::create('c', {3,3,5,5}, { + 0.238337f, 0.309664f, 0.334077f, 0.376534f, 0.342926f, 0.370734f, 0.362017f, 0.354182f, 0.379140f, 0.376275f, 0.380027f, 0.368347f, 0.356401f, 0.378316f, 0.381315f, 0.382465f, 0.370592f, 0.357055f, 0.377670f, 0.382950f, 0.383445f, 0.371718f, 0.357332f, 0.377217f, 0.383677f, 0.383933f, 0.372391f, 0.357475f, 0.376891f, 0.384062f, 0.384212f, 0.372837f, 0.357557f, 0.376646f, 0.384290f, 0.384385f, 0.373153f, 0.357610f, 0.376457f, 0.384436f, 0.384500f, 0.373389f, 0.357645f, 0.376306f, 0.384536f, 0.384581f, 0.373572f, 0.357670f, 0.376184f, 0.384606f, 0.384639f, 0.373718f, 0.357688f, 0.376082f, 0.384658f, 0.384683f, 0.373837f, 0.357702f, 0.375996f, 0.384698f, 0.384717f, 0.373935f, 0.357712f, 0.375923f, 0.384728f, 0.384743f, 0.374019f, 0.357721f, 0.375860f, 0.384752f, 0.384764f, 0.374090f, 0.357727f, 0.375804f, 0.384771f, 0.384781f, 0.374152f, 0.357733f, 0.375756f, 0.384787f, 0.384795f, 0.374205f, 0.357737f, 0.375713f, 0.384800f, 0.384807f, 0.374253f, 0.357741f, 0.375674f, 0.384811f, 0.384817f, 0.374295f, 0.357744f, 0.375640f, 0.384820f, 0.384825f, 0.374333f, 0.357747f, 0.375609f, 0.384828f, 0.384832f, 0.374366f, 0.357749f, 0.375581f, 0.384835f, 0.384839f, 0.374397f, 0.357751f, 0.375555f, 0.384841f, 0.384844f, 0.374425f, 0.357753f, 0.375531f, 0.384846f, 0.384849f, 0.374450f, 0.357754f, 0.375510f, 0.384850f, 0.384853f, 0.374473f, 0.357756f, 0.375490f, 0.384854f, 0.384856f, 0.374494f, 0.357757f, 0.375471f, 0.384858f, 0.384860f, 0.374514f, 0.357758f, 0.375454f, 0.384861f, 0.384863f, 0.374532f, 0.357759f, 0.375438f, 0.384864f, 0.384865f, 0.374549f, 0.357760f, 0.375423f, 0.384866f, 0.384868f, 0.374565f, 0.357760f, 0.375410f, 0.384868f, 0.384870f, 0.374579f, 0.357761f, 0.375397f, 0.384870f, 0.384872f, 0.374593f, 0.357762f, 0.375384f, 0.384872f, 0.384873f, 0.374606f, 0.357762f, 0.375373f, 0.384874f, 0.384875f, 
0.374618f, 0.357763f, 0.375362f, 0.384875f, 0.384876f, 0.374629f, 0.357763f, 0.375352f, 0.384877f, 0.384878f, 0.374640f, 0.357764f, 0.375342f, 0.384878f, 0.384879f, 0.374650f, 0.357764f, 0.375333f, 0.384879f, 0.384880f, 0.374660f, 0.357764f, 0.375325f, 0.384880f, 0.384881f, 0.374669f, 0.357765f, 0.375316f, 0.384881f, 0.384882f, 0.374677f, 0.357765f, 0.375309f, 0.384882f, 0.384883f, 0.374685f, 0.357765f, 0.375301f, 0.384883f, 0.384884f, 0.374693f, 0.357765f, 0.375294f, 0.384884f, 0.384884f, 0.374700f, 0.357766f, 0.375287f, 0.384885f, 0.384885f, 0.374707f, 0.357766f, 0.375281f, 0.384885f, 0.384886f, 0.374714f, 0.357766f, 0.375275f, 0.384886f} + ); +/// + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); +// ASSERT_TRUE(exp.isSameShape(out)); + // out->printBuffer("LRN BP out"); + // exp.printBuffer("LRN BP exp"); + //ASSERT_TRUE(exp.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TYPED_TEST(TypedDeclarableOpsTests8, LrnTest_BP_2) { + + auto x = NDArrayFactory::create( 'c', {3, 3, 5, 5}); + x.linspace(1); + + auto eps = NDArrayFactory::create('c', {3, 3, 5, 5}, { 0.2581989f, 0.3592106f, 0.40089184f, 0.53935987f, 0.70014f, + 0.4898979f, 0.46056613f, 0.43971977f, 0.5240002f, 0.6375767f, + 0.5274096f, 0.47771242f, 0.4443308f, 0.5163977f, 0.61701745f, + 0.5424508f, 0.48452914f, 0.44570294f, 0.5123918f, 0.6068971f, + 0.5505386f, 0.4881662f, 0.4462865f, 0.5099462f, 0.60088515f, + + 0.5555859f, 0.49042296f, 0.44658744f, 0.5083028f, 0.59690416f, + 0.55903524f, 0.4919585f, 0.44676256f, 0.5071239f, 0.59407425f, + 0.5615412f, 0.49307042f, 0.44687328f, 0.50623745f, 0.5919596f, + 0.56344414f, 0.49391258f, 0.4469477f, 0.5055468f, 0.59031945f, + 0.56493837f, 0.49457246f, 0.4470002f, 0.5049936f, 0.5890103f, + + 0.56614274f, 0.49510333f, 0.44703856f, 0.50454074f, 0.5879411f, + 0.567134f, 0.49553978f, 
0.4470674f, 0.504163f, 0.5870515f, + 0.5679643f, 0.4959048f, 0.44708967f, 0.5038433f, 0.5862998f, + 0.56866974f, 0.4962146f, 0.44710726f, 0.5035692f, 0.58565617f, + 0.56927663f, 0.49648085f, 0.4471213f, 0.5033315f, 0.5850988f, + + + 0.56980413f, 0.49671215f, 0.44713274f, 0.50312346f, 0.58461165f, + 0.57026696f, 0.49691492f, 0.4471422f, 0.50293994f, 0.58418214f, + 0.5706764f, 0.49709415f, 0.44715008f, 0.5027767f, 0.5838005f, + 0.571041f, 0.4972537f, 0.44715673f, 0.50263065f, 0.58345926f, + 0.57136786f, 0.49739665f, 0.44716236f, 0.5024992f, 0.58315235f, + + 0.5716625f, 0.49752548f, 0.4471672f, 0.5023803f, 0.5828747f, + 0.5719295f, 0.49764213f, 0.44717142f, 0.5022721f, 0.5826225f, + 0.57217246f, 0.49774826f, 0.44717506f, 0.5021734f, 0.58239233f, + 0.5723947f, 0.4978453f, 0.44717824f, 0.5020829f, 0.58218133f, + 0.57259864f, 0.49793428f, 0.44718108f, 0.5019997f, 0.5819874f, + + 0.5727864f, 0.49801624f, 0.44718358f, 0.5019227f, 0.5818083f, + 0.57296f, 0.49809194f, 0.44718578f, 0.5018515f, 0.5816426f, + 0.5731208f, 0.49816203f, 0.44718775f, 0.5017854f, 0.58148885f, + 0.57327026f, 0.49822718f, 0.4471895f, 0.5017239f, 0.5813457f, + 0.57340944f, 0.49828786f, 0.44719115f, 0.5016664f, 0.581212f, + + + 0.57353944f, 0.4983446f, 0.44719255f, 0.50161266f, 0.58108705f, + 0.5736612f, 0.49839762f, 0.4471939f, 0.50156236f, 0.5809699f, + 0.5737754f, 0.4984474f, 0.44719502f, 0.501515f, 0.58085984f, + 0.5738828f, 0.49849418f, 0.4471962f, 0.50147045f, 0.5807564f, + 0.5739839f, 0.49853817f, 0.44719717f, 0.5014284f, 0.5806588f, + + 0.5740793f, 0.49857965f, 0.4471981f, 0.5013887f, 0.5805666f, + 0.5741694f, 0.49861887f, 0.44719887f, 0.50135124f, 0.58047944f, + 0.57425463f, 0.49865603f, 0.44719967f, 0.5013157f, 0.5803969f, + 0.5743354f, 0.4986912f, 0.44720036f, 0.5012819f, 0.5803186f, + 0.57441217f, 0.49872455f, 0.44720104f, 0.5012499f, 0.58024424f, + + 0.57448506f, 0.4987563f, 0.44720164f, 0.5012194f, 0.58017343f, + 0.57455444f, 0.4987865f, 0.4472022f, 0.5011904f, 0.5801061f, + 0.57462054f, 
0.49881527f, 0.44720277f, 0.5011627f, 0.5800419f, + 0.57468355f, 0.49884263f, 0.44720328f, 0.50113624f, 0.5799805f, + 0.57474375f, 0.49886885f, 0.44720373f, 0.50111103f, 0.5799219f }); +// + auto exp = NDArrayFactory::create('c', {3,3,5,5}, { + 0.061538f, 0.055617f, 0.044643f, 0.050772f, 0.048019f, 0.030270f, 0.023819f, 0.019468f, 0.022074f, 0.023990f, 0.018221f, 0.014664f, 0.012182f, 0.013954f, 0.015685f, 0.012967f, 0.010563f, 0.008841f, 0.010185f, 0.011621f, 0.010052f, 0.008248f, 0.006934f, 0.008015f, 0.009222f, 0.008204f, 0.006764f, 0.005702f, 0.006606f, 0.007642f, 0.006929f, 0.005732f, 0.004841f, 0.005618f, 0.006523f, 0.005996f, 0.004973f, 0.004205f, 0.004887f, 0.005689f, 0.005284f, 0.004391f, 0.003717f, 0.004324f, 0.005044f, 0.004723f, 0.003931f, 0.003331f, 0.003877f, 0.004531f, 0.004270f, 0.003558f, 0.003017f, 0.003514f, 0.004112f, 0.003896f, 0.003250f, 0.002757f, 0.003213f, 0.003764f, 0.003582f, 0.002991f, 0.002539f, 0.002959f, 0.003470f, 0.003315f, 0.002770f, 0.002352f, 0.002743f, 0.003219f, 0.003085f, 0.002580f, 0.002191f, 0.002556f, 0.003002f, 0.002885f, 0.002414f, 0.002051f, 0.002393f, 0.002812f, 0.002709f, 0.002268f, 0.001927f, 0.002250f, 0.002645f, 0.002553f, 0.002138f, 0.001818f, 0.002122f, 0.002496f, 0.002415f, 0.002023f, 0.001720f, 0.002009f, 0.002363f, 0.002290f, 0.001920f, 0.001632f, 0.001906f, 0.002244f, 0.002178f, 0.001826f, 0.001553f, 0.001814f, 0.002136f, 0.002076f, 0.001741f, 0.001481f, 0.001731f, 0.002038f, 0.001984f, 0.001664f, 0.001416f, 0.001654f, 0.001949f, 0.001899f, 0.001593f, 0.001356f, 0.001584f, 0.001867f, 0.001821f, 0.001528f, 0.001301f, 0.001520f, 0.001792f, 0.001750f, 0.001469f, 0.001250f, 0.001461f, 0.001722f, 0.001683f, 0.001413f, 0.001203f, 0.001406f, 0.001658f, 0.001622f, 0.001362f, 0.001159f, 0.001355f, 0.001599f, 0.001565f, 0.001314f, 0.001119f, 0.001308f, 0.001543f, 0.001512f, 0.001270f, 0.001081f, 0.001264f, 0.001491f, 0.001462f, 0.001228f, 0.001046f, 0.001223f, 0.001443f, 0.001415f, 0.001189f, 0.001013f, 0.001184f, 
0.001397f, 0.001372f, 0.001153f, 0.000982f, 0.001148f, 0.001355f, 0.001331f, 0.001118f, 0.000952f, 0.001114f, 0.001315f, 0.001292f, 0.001086f, 0.000925f, 0.001082f, 0.001277f, 0.001255f, 0.001055f, 0.000899f, 0.001051f, 0.001241f, 0.001221f, 0.001026f, 0.000874f, 0.001023f, 0.001208f, 0.001188f, 0.000999f, 0.000851f, 0.000996f, 0.001176f, 0.001157f, 0.000973f, 0.000829f, 0.000970f, 0.001145f, 0.001128f, 0.000949f, 0.000808f, 0.000945f, 0.001117f, 0.001100f, 0.000925f, 0.000788f, 0.000922f, 0.001089f, 0.001073f, 0.000903f, 0.000769f, 0.000900f, 0.001063f, 0.001048f, 0.000882f, 0.000751f, 0.000879f, 0.001038f, 0.001024f, 0.000861f, 0.000734f, 0.000859f, 0.001015f, 0.001001f, 0.000842f, 0.000717f, 0.000840f, 0.000992f} + // 0.009859f, 0.013075f, 0.013874f, 0.017893f, 0.022344f, 0.014551f, 0.012859f, 0.011511f, 0.013311f, 0.015834f, 0.012025f, 0.010047f, 0.008601f, 0.009920f, 0.011885f, 0.009505f, 0.007636f, 0.006299f, 0.007413f, 0.009095f, 0.007446f, 0.005743f, 0.004540f, 0.005533f, 0.007033f, 0.005821f, 0.004282f, 0.003209f, 0.004123f, 0.005491f, 0.004577f, 0.003198f, 0.002247f, 0.003097f, 0.004355f, 0.003652f, 0.002412f, 0.001565f, 0.002357f, 0.003517f, 0.002965f, 0.001844f, 0.001084f, 0.001821f, 0.002893f, 0.002451f, 0.001430f, 0.000741f, 0.001428f, 0.002422f, -0.111434f, -0.105946f, -0.100351f, -0.091868f, -0.083323f, -0.078775f, -0.076222f, -0.073291f, -0.067635f, -0.061692f, -0.058943f, -0.057832f, -0.056263f, -0.052198f, -0.047768f, -0.046002f, -0.045655f, -0.044839f, -0.041748f, -0.038271f, -0.037084f, -0.037161f, -0.036786f, -0.034331f, -0.031495f, 0.000077f, -0.000673f, -0.001181f, -0.000667f, 0.000079f, -0.000089f, -0.000802f, -0.001285f, -0.000793f, -0.000079f, -0.000228f, -0.000908f, -0.001368f, -0.000896f, -0.000212f, -0.000345f, -0.000996f, -0.001434f, -0.000981f, -0.000325f, -0.000444f, -0.001067f, -0.001487f, -0.001051f, -0.000421f, 0.000697f, 0.000188f, -0.000152f, 0.000210f, 0.000731f, 0.000650f, 0.000165f, -0.000161f, 0.000185f, 0.000683f, 
0.000610f, 0.000145f, -0.000168f, 0.000164f, 0.000641f, 0.000574f, 0.000128f, -0.000172f, 0.000146f, 0.000604f, 0.000542f, 0.000113f, -0.000175f, 0.000131f, 0.000571f, -0.009490f, -0.010070f, -0.010409f, -0.009734f, -0.008834f, -0.008785f, -0.009351f, -0.009687f, -0.009054f, -0.008207f, -0.008167f, -0.008718f, -0.009050f, -0.008455f, -0.007654f, -0.007622f, -0.008159f, -0.008485f, -0.007924f, -0.007164f, -0.007138f, -0.007661f, -0.007981f, -0.007450f, -0.006728f, -0.000901f, -0.001327f, -0.001614f, -0.001310f, -0.000869f, -0.000913f, -0.001328f, -0.001607f, -0.001310f, -0.000882f, -0.000922f, -0.001326f, -0.001598f, -0.001309f, -0.000892f, -0.000930f, -0.001323f, -0.001588f, -0.001306f, -0.000900f, -0.000936f, -0.001319f, -0.001577f, -0.001302f, -0.000906f, 0.000339f, 0.000038f, -0.000164f, 0.000048f, 0.000355f, 0.000328f, 0.000035f, -0.000162f, 0.000045f, 0.000343f, 0.000318f, 0.000033f, -0.000159f, 0.000041f, 0.000332f, 0.000308f, 0.000030f, -0.000157f, 0.000039f, 0.000322f, 0.000299f, 0.000028f, -0.000155f, 0.000036f, 0.000312f, -0.004085f, -0.004479f, -0.004733f, -0.004396f, -0.003925f, -0.003925f, -0.004309f, -0.004558f, -0.004232f, -0.003775f, -0.003776f, -0.004151f, -0.004395f, -0.004079f, -0.003636f, -0.003637f, -0.004004f, -0.004242f, -0.003936f, -0.003505f, -0.003507f, -0.003866f, -0.004100f, -0.003802f, -0.003383f} + ); + + sd::ops::lrn_bp op; + auto results = op.evaluate({&x, &eps}, {1.0, 1.0, 0.5}, {2}, {}, {}, false); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(out)); + //out->printBuffer("LRN BP out"); +// exp.printIndexedBuffer("LRN exp"); + // ASSERT_TRUE(exp.equalsTo(out)); + + +} + + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests9.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests9.cpp new file mode 100644 index 000000000..91ebb5ba6 --- /dev/null +++ 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTests9.cpp @@ -0,0 +1,2592 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author Yurii Shyrma (iuriish@yahoo.com), created on 22.06.2018 +// + + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include + +using namespace sd; + + +class DeclarableOpsTests9 : public testing::Test { +public: + + DeclarableOpsTests9() { + printf("\n"); + fflush(stdout); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, reduceStDevBP_test3) { + + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {3,1}, {1.,2.,3.}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1.,2.,3.}); + auto exp = NDArrayFactory::create('c', {3,4}, {-0.335410, -0.111803, 0.111803, 0.335410, -0.670820, -0.223607, 0.223607, 0.670820, -1.006231, -0.335410, 0.335410, 1.006231}); + + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2}, {0,0}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer(); + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,0}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, reduceStDevBP_test03) { + + auto x = NDArrayFactory::create('c', {3,4}); + auto gradO1 = NDArrayFactory::create('c', {3,1}, {1.,2.,3.}); + auto gradO2 = NDArrayFactory::create('c', {3}, {1.,2.,3.}); + auto exp = NDArrayFactory::create('c', {3,4}, {-0.335410, -0.111803, 0.111803, 0.335410, -0.670820, -0.223607, 0.223607, 0.670820, -1.006231, -0.335410, 0.335410, 1.006231}); + auto axis = NDArrayFactory::create('c', {1}, {1}); + x.linspace(1); + + sd::ops::reduce_stdev_bp op; + + auto result = op.evaluate({&x, &gradO2, &axis}, {}, {}, {false, false}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printIndexedBuffer(); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + + result = op.evaluate({&x, &gradO1}, {1,0}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + output = result.at(0); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} +/* + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test1) { + + const int N = 50000; + const double lambda = 2.; + const double mean = 1. 
/ lambda; + const double std = mean; + + auto x = NDArrayFactory::create('c', {N}); + double extraParams[] = {lambda}; + + Nd4jLong *buffer = new Nd4jLong[N]; + auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); + if (rng == nullptr) + throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test1: RNG initialization failed !"); + + functions::random::RandomFunction::template execTransform>(rng, x.getBuffer(), x.shapeInfo(), extraParams); + const double actualMean = x.meanNumber().e(0); + const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + + ASSERT_NEAR(mean, actualMean, 0.01); + ASSERT_NEAR(std, actualStd, 0.01); + + destroyRandom((Nd4jPointer) rng); + delete[] buffer; + +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, exponentialDistributionInv_test2) { + + const int N = 50000; + const double lambda = 2.; + const double mean = 1. 
/ lambda; + const double std = mean; + double extraParams[] = {lambda}; + + auto x = NDArrayFactory::create('c', {N}); + auto y = NDArrayFactory::create('c', {N}); + y.linspace(0., 1./N); // [0, 1) + + + Nd4jLong *buffer = new Nd4jLong[N]; + auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); + if (rng == nullptr) + throw std::runtime_error("DeclarableOpsTests9.exponentialDistributionInv_test2: RNG initialization failed !"); + + functions::random::RandomFunction::template execTransform>(rng, y.getBuffer(), y.shapeInfo(), x.getBuffer(), x.shapeInfo(), extraParams); + + const double actualMean = x.meanNumber().e(0); + const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + + ASSERT_NEAR(mean, actualMean, 0.01); + ASSERT_NEAR(std, actualStd, 0.01); + + destroyRandom((Nd4jPointer) rng); + delete[] buffer; + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, exponentialDistribution_test1) { + + const int N = 50000; + const double lambda = 2.; + const double mean = 1. 
/ lambda; + const double std = mean; + + auto x = NDArrayFactory::create('c', {N}); + double extraParams[] = {lambda}; + + Nd4jLong *buffer = new Nd4jLong[N]; + auto rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); + if (rng == nullptr) + throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test1: RNG initialization failed !"); + + functions::random::RandomFunction::template execTransform>(rng, x.getBuffer(), x.shapeInfo(), extraParams); + const double actualMean = x.meanNumber().e(0); + const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + + ASSERT_NEAR(mean, actualMean, 0.01); + ASSERT_NEAR(std, actualStd, 0.01); + + destroyRandom((Nd4jPointer) rng); + delete[] buffer; +} + + +////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, exponentialDistribution_test2) { + + const int N = 50000; + const double lambda = 2.; + const double mean = 1. 
/ lambda; + const double std = mean; + double extraParams[] = {lambda}; + + auto x = NDArrayFactory::create('c', {N}); + auto y = NDArrayFactory::create('c', {N}); + y.linspace(-N/2.); // [-25000, 25000) + + + Nd4jLong *buffer = new Nd4jLong[N]; + // Nd4jPointer extra[2]; +#ifndef __CUDABLAS__ + sd::random::RandomBuffer* rng = (sd::random::RandomBuffer *) initRandom(nullptr, 123, N, (Nd4jPointer) buffer); + if (rng == nullptr) + throw std::runtime_error("DeclarableOpsTests9.exponentialDistribution_test2: RNG initialization failed !"); + + functions::random::RandomFunction::template execTransform>(rng, y.getBuffer(), y.shapeInfo(), x.getBuffer(), x.shapeInfo(), extraParams); + + destroyRandom((Nd4jPointer) rng); +#endif + const double actualMean = x.meanNumber().e(0); + const double actualStd = x.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + ASSERT_NEAR(mean, actualMean, 0.01); + ASSERT_NEAR(std, actualStd, 0.01); + + + + + delete[] buffer; +} +*/ + +TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) { + auto x = NDArrayFactory::create('f', {2, 2}, {1.0, 3.0, 2.0, 4.0}); + auto e = NDArrayFactory::create('c', {2, 2}, {2.0, 3.0, 4.0, 5.0}); + auto z = NDArrayFactory::create('c', {2, 2}, {0.0, 0.0, 0.0, 0.0}); + + x.applyScalar(scalar::Add, 1.0, z); + + ASSERT_EQ(e, z); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test1) { + + auto x0 = NDArrayFactory::create('c', {2,3,4}); + auto x1 = NDArrayFactory::create('c', {2,2,4}); + auto x2 = NDArrayFactory::create('c', {2,1,4}); + auto exp = NDArrayFactory::create('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.}); + + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); + + sd::ops::concat op; + + 
auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printCurrentBuffer(false); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test2) { + + auto x0 = NDArrayFactory::create('c', {1,3,1}); + auto x1 = NDArrayFactory::create('c', {1,2,1}); + auto x2 = NDArrayFactory::create('c', {1,1,1}); + auto exp = NDArrayFactory::create('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); + + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test3) { + + auto x0 = NDArrayFactory::create('c', {3}); + auto x1 = NDArrayFactory::create('c', {2}); + auto x2 = NDArrayFactory::create('c', {1}); + auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f}); + + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test4) { + + auto x0 = NDArrayFactory::create('c', {1,1,1}, {1.f}); + auto x1 = NDArrayFactory::create('c', {1,1,1}, {2.f}); + auto x2 = NDArrayFactory::create('c', {1,1,1}, {3.f}); + auto exp = NDArrayFactory::create('c', {1,3,1}, {1.f, 2.f, 3.f}); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, 
{}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test5) { + + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create('c', {1}, {2.f}); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test6) { + + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create('c', {2}, {2.f, 20.f}); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {4}, {1.f, 2.f, 20.f, 3.f}); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test7) { + + auto x0 = NDArrayFactory::create(1.f); + auto x1 = NDArrayFactory::create(2.f); + auto x2 = NDArrayFactory::create(3.f); + auto exp = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests9, concat_test8) { + + auto x0 = NDArrayFactory::create(1.f); + auto exp = NDArrayFactory::create('c', {1}, {1.f}); + + sd::ops::concat op; + + auto result = op.evaluate({&x0}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test9) { + + auto x0 = NDArrayFactory::create('c', {1}, {1.f}); + auto exp = NDArrayFactory::create('c', {1}, {1.f}); + + sd::ops::concat op; + + auto result = op.evaluate({&x0}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test10) { + + auto x0 = NDArrayFactory::create('c', {2,3,4}); + auto x1 = NDArrayFactory::create('f', {2,2,4}); + auto x2 = NDArrayFactory::create('c', {2,1,4}); + auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); + + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test11) { + + auto x0 = NDArrayFactory::create('c', {2,3,4}); + auto x1 = NDArrayFactory::create('f', {2,2,4}); + auto x2 = 
NDArrayFactory::create('f', {2,1,4}); + auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); + + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test12) { + + auto x0 = NDArrayFactory::create('c', {2,3,4}); + auto x1 = NDArrayFactory::create('f', {2,2,4}); + auto x2 = NDArrayFactory::create('f', {2,1,4}); + auto exp = NDArrayFactory::create('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f, + 13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f}); + + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test13) { + + auto x0 = NDArrayFactory::create('f', {2,3,4}); + auto x1 = NDArrayFactory::create('f', {2,2,4}); + auto x2 = NDArrayFactory::create('f', {2,1,4}); + auto exp = NDArrayFactory::create('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 
14.f, 2.f, 6.f, + 3.f, 15.f, 7.f, 19.f,11.f, 23.f, 3.f, 11.f, 7.f, 15.f, 3.f, 7.f, 4.f, 16.f, 8.f, 20.f,12.f, 24.f, 4.f, 12.f, 8.f, 16.f, 4.f, 8.f}); + + x0.linspace(1); + x1.linspace(1); + x2.linspace(1); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +TEST_F(DeclarableOpsTests9, concat_test14) { + + NDArray x0('c', {1, 40, 60}, sd::DataType::FLOAT32); + NDArray x1('c', {1, 40, 60}, sd::DataType::FLOAT32); + + x0 = 1.; + x1 = 2.; + + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->shapeInfo(), {0}); + ASSERT_TRUE(2 == numOfTads); + + for (int e = 0; e < numOfTads; ++e) { + NDArray tad = (*z)(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((e+1)*1., mean, 1e-5); + } + + +} + +TEST_F(DeclarableOpsTests9, concat_test15) { + auto x = NDArrayFactory::create('c', {2}, {1, 0}); + auto y = NDArrayFactory::create (3.0f); + auto exp = NDArrayFactory::create('c', {3}, {1, 0, 3}); + + sd::ops::concat op; + auto result = op.evaluate({&x, &y}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test16) { + + auto x = NDArrayFactory::create('c', {0,2,3}); + auto y = NDArrayFactory::create('c', {0,2,3}); + auto exp = NDArrayFactory::create('c', {0,2,3}); + + sd::ops::concat op; + auto result = op.evaluate({&x, &y}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); +} + +////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests9, concat_test17) { + + NDArray x0('c', {1, 55, 40}, sd::DataType::FLOAT32); + NDArray x1('c', {1, 55, 40}, sd::DataType::FLOAT32); + + x0 = 1.; + x1 = 2.; + + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + // z->printShapeInfo(); + // z->printIndexedBuffer(); + + Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->shapeInfo(), {0}); + ASSERT_TRUE(2 == numOfTads); + + for (int e = 0; e < numOfTads; ++e) { + NDArray tad = (*z)(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((e+1)*1., mean, 1e-5); + } +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test18) { + Context context(1); + Nd4jLong axis = 0; + + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < 2000; e++) { + auto array = NDArrayFactory::create_('c', {1, 300}); + array->assign(e); + context.setInputArray(e, array, true); + } + + auto z = NDArrayFactory::create('c', {2000, 300}); + context.setOutputArray(0, &z, false); + context.setIArguments(&axis, 1); + + sd::ops::concat op; + op.execute(&context); + + for (int e = 0; e < 2000; e++) { + auto exp = NDArrayFactory::create('c', {300}); + exp.assign(e); + auto row = z(e, {0}); + ASSERT_EQ(exp, row); + } +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test19) { + + Context context(1); + Nd4jLong axis = 0; + + // we crate bunch of arrays, filled with specific values + for (int e = 0; e < 10; e++) { + auto array = NDArrayFactory::create_('c', {1, 5, 20}); + array->assign(e); + context.setInputArray(e, array, true); + } + + auto z = NDArrayFactory::create('c', {10, 5, 20}); + context.setOutputArray(0, &z, false); + context.setIArguments(&axis, 1); + + sd::ops::concat op; + op.execute(&context); + + for (int e = 0; e < 10; e++) + ASSERT_NEAR((float) e, z(e, 
{0}).meanNumber().e(0), 1e-5f); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test20) { + auto x0 = NDArrayFactory::create('c', {1, 100, 150}); + auto x1 = NDArrayFactory::create('c', {1, 100, 150}); + auto x2 = NDArrayFactory::create('c', {1, 100, 150}); + auto x3 = NDArrayFactory::create('c', {1, 100, 150}); + + x0.assign(1.0); + x1.assign(2.0); + x2.assign(3.0); + x3.assign(4.0); + + sd::ops::concat op; + auto result = op.evaluate({&x0, &x1, &x2, &x3}, {}, {0}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + Nd4jLong numOfTads= ShapeUtils::getNumOfSubArrs(z->shapeInfo(), {0}); + ASSERT_TRUE(4 == numOfTads); + + for (int e = 0; e < numOfTads; e++) { + NDArray tad = (*z)(e, {0}); + auto mean = tad.meanNumber().e(0); + ASSERT_NEAR((float) e+1, mean, 1e-5); + } + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test21) { + + NDArray x0('c', {1,4,5}, sd::DataType::FLOAT32); + NDArray x1('c', {2,4,5}, sd::DataType::FLOAT32); + NDArray z('f', {3,4,5}, sd::DataType::FLOAT32); + + x0 = 0.; + x1 = 1.; + + sd::ops::concat op; + auto status = op.execute({&x0, &x1}, {&z}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test22) { + + NDArray x0('c', {1,6}, {1,2,3,4,5,6}, sd::DataType::FLOAT32); + NDArray x1('c', {1,6}, {7,8,9,10,11,12}, sd::DataType::FLOAT32); + NDArray output('f', {2,6}, sd::DataType::FLOAT32); + NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32); + + sd::ops::concat op; + + auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, 
concat_test23) { + + NDArray x0('c', {1,4}, {1,2,3,4},sd::DataType::FLOAT32); + NDArray x1('c', {1,4}, {5,6,7,8},sd::DataType::FLOAT32); + NDArray output('c', {2,4}, sd::DataType::FLOAT32); + NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}, sd::DataType::FLOAT32); + + sd::ops::concat op; + + auto status = op.execute({&x0, &x1}, {&output}, {}, {0}, {}); + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(exp.equalsTo(output)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test24) { + auto x = NDArrayFactory::create('c', {2, 1}, {1, 1}); + auto y = NDArrayFactory::create('c', {2, 1}, {0, 0}); + auto e = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto z = NDArrayFactory::create('c', {2, 2}); + + sd::ops::concat op; + auto status = op.execute({&x, &y}, {&z}, {}, {1}, {}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test25) { + + auto x0 = NDArrayFactory::create('c', {1,4}, {1,2,3,4}); + auto x1 = NDArrayFactory::create('c', {1,4}, {5,6,7,8}); + auto axis = NDArrayFactory::create('c', {1}, {0.}); + auto exp = NDArrayFactory::create('c', {2,4}, {1,2,3,4,5,6,7,8}); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &axis}, {}, {}, {true}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test26) { + + NDArray x0('f', {1, 2, 3}, sd::DataType::INT32); + NDArray x1('f', {1, 2, 3}, sd::DataType::INT32); + NDArray x2('f', {1, 2, 3}, sd::DataType::INT32); + + NDArray exp('f', {3, 2, 3}, {0, 6, 12, 3, 9, 15, 1, 7, 13, 4, 10, 16, 2, 8, 14, 5, 11, 17}, sd::DataType::INT32); + + x0.linspace(0); + x1.linspace(6); + 
x2.linspace(12); + + sd::ops::concat op; + + auto result = op.evaluate({&x0, &x1, &x2}, {}, {0}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + // output->printLinearBuffer(); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, concat_test27) { + + auto x1 = NDArrayFactory::create('c', {0,1}); + auto x2 = NDArrayFactory::create('c', {0,1}); + auto x3 = NDArrayFactory::create('c', {0,1}); + auto x4 = NDArrayFactory::create('c', {0,1}); + + std::vector expShape = {0, 4}; + + sd::ops::concat op; + auto result = op.evaluate({&x1, &x2, &x3, &x4}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(expShape)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, tile_bp_test1) { + + auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); + auto gradO = NDArrayFactory::create('c', {4, 9}); + auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44}); + + gradO.linspace(0.01, 0.01); + + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {2, 3}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, tile_bp_test2) { + + auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); + auto gradO = NDArrayFactory::create('c', {2, 9}); + auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.12, 0.15, 0.18, 0.39, 0.42, 0.45}); + + gradO.linspace(0.01, 0.01); + + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1, 3}); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), 
results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, tile_bp_test3) { + + auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); + auto gradO = NDArrayFactory::create('c', {2, 3}); + auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.01, 0.02, 0.03,0.04, 0.05, 0.06}); + + gradO.linspace(0.01, 0.01); + + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1, 1}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, tile_bp_test4) { + + auto input = NDArrayFactory::create('c', {6}, {1.,2.,3.,4.,5.,6.}); + auto gradO = NDArrayFactory::create('c', {12}); + auto gradIExp = NDArrayFactory::create('c', {6}, {0.08, 0.1 , 0.12, 0.14, 0.16, 0.18}); + + gradO.linspace(0.01, 0.01); + + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {2}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, tile_bp_test5) { + + auto input = NDArrayFactory::create('c', {1}, {1.}); + auto gradO = NDArrayFactory::create('c', {1}); + auto gradIExp = NDArrayFactory::create('c', {1}, {0.01}); + + gradO.linspace(0.01, 0.01); + + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); + + +} + +////////////////////////////////////////////////////////////////////// 
+TEST_F(DeclarableOpsTests9, tile_bp_test6) { + + auto input = NDArrayFactory::create('c', {2, 1, 3}, {1.,2.,3.,4.,5.,6.}); + auto gradO = NDArrayFactory::create('c', {2, 3, 6}); + auto gradIExp = NDArrayFactory::create('c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71}); + + gradO.linspace(0.01, 0.01); + + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &gradO}, {}, {1, 3, 2}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, tile_bp_test7) { + + auto input = NDArrayFactory::create('c', {2, 1, 3}, {1.,2.,3.,4.,5.,6.}); + auto reps = NDArrayFactory::create('c', {1, 3}, {1, 3, 2}); + auto gradO = NDArrayFactory::create('c', {2, 3, 6}); + auto gradIExp = NDArrayFactory::create('c', {2, 1, 3}, {0.51, 0.57, 0.63, 1.59, 1.65, 1.71}); + + gradO.linspace(0.01, 0.01); + + sd::ops::tile_bp op; + auto results = op.evaluate({&input, &reps, &gradO}, {}, {}); + auto gradI = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, tile_test1) { + + auto input = NDArrayFactory::create('c', {1, 6}, {1.,2.,3.,4.,5.,6.}); + auto reps = NDArrayFactory::create('c', {1, 2}, {2, 1}); + auto expOut = NDArrayFactory::create('c', {2, 6,}, {1.,2.,3.,4.,5.,6., 1.,2.,3.,4.,5.,6.}); + + sd::ops::tile op; + auto results = op.evaluate({&input, &reps}, {}, {}); + auto out = results.at(0); + + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(expOut.isSameShape(out)); + ASSERT_TRUE(expOut.equalsTo(out)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, TestDropout_BP_1) { + + NDArray x('c', 
{2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray errs('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray shape('c', {2}, {2, 2}); + sd::ops::dropout_bp op; + + auto ress = op.evaluate({&x, &errs, &shape}, {0.2f}, {113}); + + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + //ress.at(0)->printIndexedBuffer("Result is "); + //x.printIndexedBuffer("Input is"); + ASSERT_FALSE(ress.at(0)->equalsTo(errs)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, TestDropout_1) { + + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); +// NDArray errs('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + //NDArray shape({2.f, 2.f}); + sd::ops::dropout op; + x.linspace(1); + auto ress = op.evaluate({&x}, {0.2f}, {113}); + + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + NDArray* res = ress.at(0); //->printIndexedBuffer("Result is "); + //x.printIndexedBuffer("Input is"); + //res->printIndexedBuffer("Result for Dropout_1"); + auto countZero = res->reduceNumber(reduce::CountZero); + ASSERT_NEAR(countZero.e(0), 80, 5); + auto ress2 = op.evaluate({&x}, {0.2f}, {113}); + + ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); + NDArray* res2 = ress2.at(0); + + countZero = res->reduceNumber(reduce::CountZero); + ASSERT_NEAR(countZero.e(0), 80, 5); + //res2->printIndexedBuffer("Result for Dropout_2"); + ASSERT_TRUE(res->equalsTo(res2)); + //res->printIndexedBuffer("FF dropout"); + //res2->printIndexedBuffer("BP dropout"); + + + +} + +TEST_F(DeclarableOpsTests9, Test_DropoutInverted_01) { + NDArray x0('c', {10, 10}, sd::DataType::FLOAT32); + NDArray x1('c', {10, 10}, sd::DataType::FLOAT32); + + x0.linspace(1); + x1.linspace(1); +/* + float prob[] = {0.5f}; + Nd4jLong* _bufferA = new Nd4jLong[100000]; + long _seed = 119L; + auto _rngA = (sd::random::RandomBuffer *) initRandom(nullptr, _seed, 100000, (Nd4jPointer) _bufferA); + + x0. 
applyTransform(random::DropOutInverted, &x0, prob); +// x1.template applyRandom>(_rngB, nullptr, &x1, prob); +// x0.printIndexedBuffer("01Result1"); + int count = 0; + for (int e = 0; e < x0.lengthOf(); e++) + if (x0.e(e) != 0.f) + count++; +// nd4j_printf("\nX0 count %i\n", count); +// ASSERT_TRUE(x0.equalsTo(&x1)); + + // this check is required to ensure we're calling wrong signature +// ASSERT_FALSE(x0.equalsTo(nexp0)); +// ASSERT_FALSE(x0.equalsTo(nexp1)); +// ASSERT_FALSE(x0.equalsTo(nexp2)); + destroyRandom(_rngA); + delete [] _bufferA; +*/ + sd::ops::dropout op; + + auto ress = op.evaluate({&x1}, {0.5f}, {119}); + + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + //ress.at(0)->printIndexedBuffer("01Dropout result is "); + auto count = ress.at(0)->reduceNumber(reduce::CountNonZero); +// nd4j_printf("\n01Dropout count %i\n\n", count); + + sd::ops::dropout_bp op2; + //NDArray exp('c', {10,10}, {4.f, 0.f, 12.f, 0.f, 20.f, 24.f, 0.f, 32.f, 0.f, 0.f, 0.f, 0.f, 52.f, 56.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 84.f, 88.f, 0.f, 0.f, 0.f, 0.f, 108.f, 0.f, 0.f, 120.f, 0.f, 0.f, 132.f, 0.f, 0.f, 0.f, 0.f, 0.f, 156.f, 0.f, 164.f, 168.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 200.f, 204.f, 0.f, 0.f, 0.f, 220.f, 0.f, 0.f, 232.f, 236.f, 240.f, 0.f, 248.f, 0.f, 0.f, 260.f, 0.f, 0.f, 0.f, 276.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 316.f, 0.f, 324.f, 0.f, 0.f, 336.f, 0.f, 0.f, 0.f, 0.f, 356.f, 0.f, 0.f, 368.f, 0.f, 0.f, 0.f, 384.f, 388.f, 0.f, 0.f, 400.f}); + //02Dropout result is [4.000000, 0.000000, 12.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 36.000000, 0.000000, 0.000000, 0.000000, 0.000000, 56.000000, 60.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 88.000000, 0.000000, 96.000000, 0.000000, 0.000000, 108.000000, 0.000000, 0.000000, 120.000000, 0.000000, 128.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 156.000000, 0.000000, 164.000000, 0.000000, 0.000000, 0.000000, 0.000000, 184.000000, 0.000000, 0.000000, 0.000000, 
200.000000, 0.000000, 0.000000, 0.000000, 216.000000, 0.000000, 0.000000, 0.000000, 232.000000, 0.000000, 240.000000, 0.000000, 248.000000, 0.000000, 0.000000, 260.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 308.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 348.000000, 0.000000, 356.000000, 0.000000, 0.000000, 0.000000, 0.000000, 376.000000, 0.000000, 384.000000, 0.000000, 0.000000, 0.000000, 400.000000] + + auto ressX = op2.evaluate({&x1, &x1}, {0.5f}, {119}); // , false, sd::DataType::FLOAT32); // skipped due given by default + //x0.printIndexedBuffer("X0"); + //x1.printIndexedBuffer("X1"); + ASSERT_EQ(ND4J_STATUS_OK, ressX.status()); + auto ressY = op2.evaluate({&x1, &x0}, {0.5f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ressY.status()); + //ressY->at(0)->printIndexedBuffer("BP"); + //ress.at(0)->printIndexedBuffer("FF"); + bool ret = true; + for (int e = 0; e < ress.at(0)->lengthOf(); e++) { + if (ress.at(0)->e(e) == 0.f) + if (ressX.at(0)->e(e) != ress.at(0)->e(e)) { + ret = false; + break; + } + } + ASSERT_TRUE(ret); + // ASSERT_FALSE(ressX->at(0)->equalsTo(ressY->at(0))); + //ressX->at(0)->printIndexedBuffer("02Dropout result is "); +/* float countZero = ressX->at(0)->template reduceNumber>(); + ASSERT_NEAR(countZero, 50.f, 5.f); + countZero = ress.at(0)->template reduceNumber>(); + ASSERT_NEAR(countZero, 50.f, 5.f); + countZero = ressY->at(0)->template reduceNumber>(); + ASSERT_NEAR(countZero, 50.f, 5.f); + */ +// ASSERT_TRUE(exp.equalsTo(ressX->at(0))); + + +} + +TEST_F(DeclarableOpsTests9, Test_Dropout_BP_2) { + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); + + x.linspace(1); + + sd::ops::dropout op; + + auto ress = op.evaluate({&x}, {0.5f}, {119}); + + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); +// ress.at(0)->printIndexedBuffer("01Dropout result is "); + + sd::ops::dropout_bp op2; + + auto ressX = op2.evaluate({&x, &x}, {0.5f}, 
{119}); + + ASSERT_EQ(ND4J_STATUS_OK, ressX.status()); + auto ressY = op2.evaluate({&x, &x}, {0.5f}, {119}); + ASSERT_EQ(ND4J_STATUS_OK, ressY.status()); + + //ress.at(0)->printIndexedBuffer("FF Dropout result is "); + //ressY->at(0)->printIndexedBuffer("BP Dropout result is "); + + + auto countZero = ress.at(0)->reduceNumber(reduce::CountZero); + ASSERT_NEAR(countZero.e(0), 50.f, 10.f); + countZero = ressX.at(0)->reduceNumber(reduce::CountZero); + //nd4j_printf("X zero count is %f\n", countZero); + ASSERT_NEAR(countZero.e(0), 50.f, 10.f); + countZero = ressY.at(0)->reduceNumber(reduce::CountZero); + //nd4j_printf("Y zero count is %f\n", countZero); + ASSERT_NEAR(countZero.e(0), 50.f, 10.f); +// ASSERT_TRUE(exp.equalsTo(ressX->at(0))); + ASSERT_TRUE(ressX.at(0)->equalsTo(ressY.at(0))); + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, Test_AlphaDropout_BP_1) { + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); + NDArray eps('c', {10, 10}, sd::DataType::FLOAT32); + + x.linspace(1); + eps.linspace(1); + + sd::ops::alpha_dropout_bp op; + + auto ress = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); + + ASSERT_EQ(ND4J_STATUS_OK, ress.status()); + NDArray* res = ress.at(0); + + auto ress2 = op.evaluate({&x, &eps}, {0.5f, 0.5f, 1.5f, 1.6f}, {119}); + + ASSERT_EQ(ND4J_STATUS_OK, ress2.status()); + NDArray* res2 = ress2.at(0); + //res->printIndexedBuffer("Result1AlphaBP1"); + //res2->printIndexedBuffer("Result1AlphaBP2"); + ASSERT_TRUE(res2->equalsTo(res)); + +} + +TEST_F(DeclarableOpsTests9, test_range_int_1) { + auto x0 = NDArrayFactory::create(0); + auto x1 = NDArrayFactory::create(2); + auto x2 = NDArrayFactory::create(1); + + sd::ops::range op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + +} + +TEST_F(DeclarableOpsTests9, test_range_empty_1) { + auto x0 = NDArrayFactory::create(0); + auto x1 = 
NDArrayFactory::create(0); + auto x2 = NDArrayFactory::create(1); + + sd::ops::range op; + auto result = op.evaluate({&x0, &x1, &x2}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(z->isEmpty()); + +} + + +TEST_F(DeclarableOpsTests9, test_broadcast_bool_1) { + auto x = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + auto y = NDArrayFactory::create('c', {1, 2, 4, 4}); + auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + + std::vector dims = {0, 2, 3, 4}; + x.applyBroadcast(broadcast::LessThan, dims, y, z); +} + +TEST_F(DeclarableOpsTests9, test_broadcast_bool_2) { + auto orig = NDArrayFactory::create('c', {1, 7, 4, 4}); + std::vector list = {0,0, 0,2, 0,0, 0,0}; + auto x = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + + auto y = orig(list, true); + + auto z = NDArrayFactory::create('c', {1, 3, 2, 4, 4}); + + std::vector dims = {0, 2, 3, 4}; + x.applyBroadcast(broadcast::LessThan, dims, y, z); + +} + +TEST_F(DeclarableOpsTests9, test_unstack_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + + sd::ops::unstack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(5, result.size()); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) { + auto x = NDArrayFactory::create({1, 2, 3, 4, 5}); + x.linspace(1.0); + auto z1 = NDArrayFactory::create(1); + auto z2 = NDArrayFactory::create(2); + auto z3 = NDArrayFactory::create(3); + auto z4 = NDArrayFactory::create(4); + auto z5 = NDArrayFactory::create(5); + std::vector z({&z1, &z2, &z3, &z4, &z5}); + sd::ops::unstack op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_EQ(5, result.size()); + for (size_t i = 0; i < result.size(); i++) { + ASSERT_TRUE(result.at(i)->isSameShape(z[i])); + ASSERT_TRUE(result.at(i)->equalsTo(z[i])); + } + +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, cumprod_1) { + + auto inputC = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto axis = NDArrayFactory::create(1); + + auto expFF = NDArrayFactory::create('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240.,11., 132.,1716., 24024.,360360.}); + auto expTF = NDArrayFactory::create('c', {3, 5}, {1, 1, 2, 6, 24,1, 6, 42, 336, 3024,1, 11, 132, 1716, 24024}); + + auto expFT = NDArrayFactory::create('c', {3, 5}, {120, 120, 60, 20, 5,30240, 5040, 720, 90, 10,360360, 32760, 2730, 210, 15}); //+++ + auto expTT = NDArrayFactory::create('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1}); + + int exclusive, reverse; + + //************************************// + exclusive = 0; reverse = 0; + + sd::ops::cumprod op; + auto result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + auto z = result.at(0); + ASSERT_TRUE(expFF.equalsTo(z)); + + + //************************************// + exclusive = 1; reverse = 0; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTF.equalsTo(z)); + + + //************************************// + exclusive = 0; reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expFT.equalsTo(z)); + + + //************************************// + exclusive = 1; reverse = 1; + + result = op.evaluate({&inputC, &axis}, {}, {exclusive, reverse}); + ASSERT_EQ(Status::OK(), result.status()); + z = result.at(0); + ASSERT_TRUE(expTT.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, cumprod_2) { + + NDArray x('c', {2, 1500}, 
sd::DataType::FLOAT32); + NDArray x0 = x(0, {0}); + NDArray x1 = x(1, {0}); + x0.linspace(1, 0.1); + x1.linspace(1, 0.1); + + NDArray exp('c', {2, 1500}, sd::DataType::FLOAT32); + NDArray exp0 = exp(0, {0}); + NDArray exp1 = exp(1, {0}); + + exp0.p(0, 1.f); + exp1.p(0, 1.f); + + for (int i = 1; i < 1500; ++i) { + const auto prev = exp0.e(i-1); + exp0.p(i, prev * x0.e(i)); + exp1.p(i, prev * x1.e(i)); + } + + sd::ops::cumprod op; + auto result = op.evaluate({&x}, {}, {0, 0, 1}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, cumprod_bp_check_1) { + + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); + + x.linspace(1); + + const OpArgsHolder argsHolderFF({&x}, {}, {0, 0}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {0, 0}); + + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); + + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, cumprod_bp_check_2) { + + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); + + x.linspace(1); + + const OpArgsHolder argsHolderFF({&x}, {}, {1, 1}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 1}); + + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); + + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, cumprod_bp_check_3) { + + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); + + x.linspace(1); + + 
const OpArgsHolder argsHolderFF({&x}, {}, {1, 0}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 0}); + + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); + + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, cumprod_bp_check_4) { + + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); + + x.linspace(1); + + const OpArgsHolder argsHolderFF({&x}, {}, {0, 1}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {0, 1}); + + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); + + ASSERT_TRUE(isGradCorrect); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, cumsum_bp_check_2) { + + auto x = NDArrayFactory::create('c', {4, 4}); + auto gradO = NDArrayFactory::create('c', {4, 4}); + + x.linspace(1); + + const OpArgsHolder argsHolderFF({&x}, {}, {1, 1}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {1, 1}); + + sd::ops::cumsum opFF; + sd::ops::cumsum_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, cumprod_test1) { + + auto inputC = NDArrayFactory::create('c', {3, 5}, {1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.}); + auto axis = NDArrayFactory::create(1.); + + auto expFF = NDArrayFactory::create('c', {3, 5}, {1., 2., 6., 24., 120., 6., 42., 336., 3024., 30240.,11., 132.,1716., 24024.,360360.}); + auto expTF = NDArrayFactory::create('c', {3, 5}, {1, 1, 2, 6, 24,1, 6, 42, 336, 
3024,1, 11, 132, 1716, 24024}); + + auto expFT = NDArrayFactory::create('c', {3, 5}, {120, 120, 60, 20, 5,30240, 5040, 720, 90, 10,360360, 32760, 2730, 210, 15}); //+++ + auto expTT = NDArrayFactory::create('c', {3, 5}, {120, 60, 20, 5, 1,5040, 720, 90, 10, 1,32760, 2730, 210, 15, 1}); + auto gradO = NDArrayFactory::create('c', {3, 5}); + + int exclusive, reverse; + + //************************************// + exclusive = 0; reverse = 0; + + const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse}); + const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, {exclusive, reverse}); + + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1}, {1, 1},GradCheck::MEAN); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, cumprod_test2) { + + auto inputC = NDArrayFactory::create('c', {2, 2}); + auto axis = NDArrayFactory::create(1.); + + auto gradO = NDArrayFactory::create('c', {2, 2}); + + int exclusive, reverse; + + //************************************// + exclusive = 0; reverse = 0; + inputC.linspace(1); + const OpArgsHolder argsHolderFF({&inputC, &axis}, {}, {exclusive, reverse}); + const OpArgsHolder argsHolderBP({&inputC, &axis, &gradO}, {}, {exclusive, reverse}); + + sd::ops::cumprod opFF; + sd::ops::cumprod_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1}, {1, 1},GradCheck::MEAN); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {3, 4}, {-0.6f, 
-0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 5.5f, 4.f, 2.7f, 1.6f, 0.7f, 0.f, -0.5f,-0.8f, -0.9f, -0.8f, -0.5f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + + sd::ops::prelu op; + + auto result = op.evaluate({&x, &alpha}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {3}, {-0.6f, 2.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test3) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {3,1}, {-0.6f, 2.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, 
result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {1, 3}, {-0.6f, 2.f, 4.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, 6.6f, 6.f, 5.4f, -16.f, -14.f, -12.f, -10.f, -16.f, -12.f, -8.f, -4.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test5) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {4}, {-0.6f, 2.f, 4.f, -1.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {7.2f, -22.f, -40.f, 9.f, 4.8f, -14.f, -24.f, 5.f, 2.4f, -6.f, -8.f, 1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test6) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, 
-6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create('c', {1,1,1}, {-2.}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1,0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test7) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1,0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test8) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {24.f, 22.f, 20.f, 18.f, 16.f, 14.f, 12.f, 10.f, 8.f, 6.f, 4.f, 2.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, 
{1,0,1,0,1,0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test9) { + + auto x = NDArrayFactory::create('c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test10) { + + auto x = NDArrayFactory::create('c', {2, 4}, {-4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f}); + auto alpha = NDArrayFactory::create(-2.f); + auto exp = NDArrayFactory::create('c', {2, 4}, {8.f, 6.f, 4.f, 2.f,0.f, 1.f, 2.f, 3.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test11) { + + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = NDArrayFactory::create('c', {4}, {0.f, -0.5f, 0.5f, -1.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {0.f, 0.f, 0.f, 0.f, 0.f, 22.5f, 22.f, 21.5f, 21.f, 20.5f, -20.f, -19.5f, -19.f, -18.5f, -18.f, 35.f, 34.f, 33.f, + 32.f, 31.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.5f, 12.f, 11.5f, 11.f, 10.5f, -10.f, -9.5f, -9.f, -8.5f, -8.f, 15.f, + 14.f, 13.f, 12.f, 11.f, 0.f, 0.f, 0.f, 0.f, 0.f, 2.5f, 2.f, 
1.5f, 1.f, 0.5f, 0.f, 1.f, 2.f, 3.f, 4.f, + 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, + 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, + 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, + 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {1,3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test12) { + + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = NDArrayFactory::create('c', {3,5}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, + 9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, + -2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, + 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {-1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + 
ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test13) { + + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = NDArrayFactory::create('c', {5,3}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 31.5f, 26.4f, 21.5f, 16.8f, 12.3f, 28.f, 23.4f, 19.f, 14.8f, 10.8f, 24.5f, 20.4f, 16.5f, 12.8f, + 9.3f, 6.f, 2.9f, 0.f, -2.7f, -5.2f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, 4.f, 1.9f, 0.f, -1.7f, -3.2f, 3.f, 1.4f, 0.f, -1.2f, + -2.2f, -3.f, -3.6f, -4.f, -4.2f, -4.2f, -1.5f, -1.6f, -1.5f, -1.2f, -0.7f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, + 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, + 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, + 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {-1, 2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_test14) { + + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + auto alpha = NDArrayFactory::create('c', {2,10}, {-0.7f, -0.6f, -0.5f, -0.4f, -0.3f, -0.2f, -0.1f, 0.f, 0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4, 5}, {35.f, 29.4f, 24.f, 18.8f, 13.8f, 9.f, 4.4f, 0.f, -4.2f, -8.2f, -12.f, -15.6f, 
-19.f, -22.2f, -25.2f, -28.f, -30.6f, + -33.f,-35.2f, -37.2f, 21.f, 17.4f, 14.f, 10.8f, 7.8f, 5.f, 2.4f, 0.f, -2.2f, -4.2f, -6.f, -7.6f, -9.f, -10.2f, + -11.2f, -12.f, -12.6f, -13.f, -13.2f, -13.2f, 7.f, 5.4f, 4.f, 2.8f, 1.8f, 1.f, 0.4f, 0.f, -0.2f, -0.2f, 0.f, + 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, + 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, + 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, + 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f}); + + sd::ops::prelu op; + auto result = op.evaluate({&x, &alpha}, {}, {-2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test1) { + + auto x = NDArrayFactory::create('c', {2, 3, 16}, { + 0.865595f, 0.381197f, 0.911656f, 0.256752f, 0.084921f, 0.070434f, 0.469923f, 0.269935f, 0.510656f, 0.949777f, 0.926772f, 0.622540f, 0.688253f, 0.164974f, + 0.068558f, 0.031173f, 0.910035f, 0.219362f, 0.731336f, 0.135392f, 0.449875f, 0.020135f, 0.891820f, 0.907567f, 0.114376f, 0.652253f, 0.892939f, 0.698095f, + 0.423831f, 0.971155f, 0.968733f, 0.194465f, 0.852475f, 0.642962f, 0.417665f, 0.768379f, 0.753035f, 0.738440f, 0.046251f, 0.659487f, 0.486230f, 0.246724f, + 0.276700f, 0.103631f, 0.843105f, 0.562587f, 0.784459f, 0.109871f, 0.455828f, 0.129641f, 0.002471f, 0.148281f, 0.976162f, 0.603573f, 0.752530f, 0.249840f, + 0.723716f, 0.658430f, 0.661057f, 0.328042f, 0.338351f, 0.903157f, 0.485580f, 0.405103f, 0.335052f, 0.509858f, 0.764852f, 0.764527f, 0.382572f, 0.962121f, + 0.296145f, 0.602766f, 0.169683f, 0.750371f, 0.993936f, 0.914704f, 0.199342f, 0.858098f, 0.617198f, 
0.219334f, 0.167574f, 0.305204f, 0.960773f, 0.537944f, + 0.245441f, 0.787276f, 0.968920f, 0.980918f, 0.615237f, 0.355165f, 0.480441f, 0.304282f, 0.961229f, 0.639195f, 0.017776f, 0.836153f + }); + auto threshold = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141}); + + sd::ops::compare_and_bitpack op; + auto result = op.evaluate({&x, &threshold}, {}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test2) { + + auto x = NDArrayFactory::create('c', {2, 3, 16}, { + true, false, true, false, false, false, false, false, true, + true, true, true, true, false, false, false, true, false, + true, false, false, false, true, true, false, true, true, + true, false, true, true, false, true, true, false, true, + true, true, false, true, false, false, false, false, true, + true, true, false, false, false, false, false, true, true, + true, false, true, true, true, false, false, true, false, + false, false, true, true, true, false, true, false, true, + false, true, true, true, false, true, true, false, false, + false, true, true, false, true, true, true, true, false, + false, false, true, true, false, true + }); + //threshold is ignored here ,actually + auto threshold = NDArrayFactory::create(true); + auto exp = NDArrayFactory::create('c', {2, 3, 2}, {160, 248, 163, 118, 221, 14, 14, 228, 117, 118, 55, 141}); + + sd::ops::compare_and_bitpack op; + auto result = op.evaluate({&x, &threshold}, {}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test3) { + + auto x = 
NDArrayFactory::create('c', {2, 0, 3, 16}); + auto threshold = NDArrayFactory::create(0.5f); + auto exp = NDArrayFactory::create('c', {2, 0, 3, 2}); + + sd::ops::compare_and_bitpack op; + auto result = op.evaluate({&x, &threshold}, {}, {}, {}); + auto output = result.at(0); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + output->printShapeInfo("output"); + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test4) { + + auto x = NDArrayFactory::create('c', {2, 0, 3, 13}); + auto threshold = NDArrayFactory::create(0.5f); + sd::ops::compare_and_bitpack op; + + ASSERT_THROW(op.evaluate({&x, &threshold}, {}, {}, {}), std::invalid_argument); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test5) { + + auto x = NDArrayFactory::create('c', {2, 0, 3, 13}); + auto threshold = NDArrayFactory::create(0.5f); + auto out = NDArrayFactory::create('c', {2, 0, 3, 1}); + sd::ops::compare_and_bitpack op; + + ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::invalid_argument); + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test6) { + + auto x = NDArrayFactory::create('c', {2, 0, 3, 8}); + auto threshold = NDArrayFactory::create(0.5f); + auto out = NDArrayFactory::create('c', {2, 0, 3, 2}); + sd::ops::compare_and_bitpack op; + //shape mismatch throws runtime error + ASSERT_THROW(op.execute({&x, &threshold}, {&out}, {}, {}), std::runtime_error); + +} + +TEST_F(DeclarableOpsTests9, compare_and_bitpack_test7) { + constexpr int pp = 32*32*16; + constexpr int s1 = 3; + constexpr int t1 = 8; + std::vector shape1 = {pp}; + std::vector strides1 = {s1}; + std::vector shape2 = {pp/8}; + std::vector strides2 = {t1}; + ShapeDescriptor desc1 
(DataType::BOOL, 'c', shape1, strides1, s1); + ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, t1); + auto x = NDArrayFactory::create(desc1); + auto output = NDArrayFactory::create(desc2); + auto exp = NDArrayFactory::create(desc2); + auto threshold = NDArrayFactory::create(true); + auto buff = x.bufferAsT(); + uint8_t *expBuff = exp.bufferAsT(); + //generate test + for(int l=0;l shape1 = {pp,pp,pp}; + std::vector strides1 = {s3 , s2 , s1}; + std::vector shape2 = {pp,pp,pp/8}; + std::vector strides2 = {t3 , t2 , t1}; + ShapeDescriptor desc1 (DataType::BOOL, 'c', shape1, strides1, 0); + ShapeDescriptor desc2 (DataType::UINT8, 'c', shape2, strides2, 0); + auto x = NDArrayFactory::create(desc1); + auto output = NDArrayFactory::create(desc2); + auto exp = NDArrayFactory::create(desc2); + auto threshold = NDArrayFactory::create(true); + auto buff = x.bufferAsT(); + uint8_t *expBuff = exp.bufferAsT(); + //generate test + for(int i=0;i('c', {2, 3, 4}, {-12.f, -11.f, -10.f, -9.f, -8.f, -7.f, -6.f, -5.f, -4.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 0.f,0.f, 0.f, 0.f, 3.f,4.f, 5.f, 6.f, 7.f,8.f, 9.f,10.f,11.f}); + + sd::ops::thresholdedrelu op; + + auto result = op.evaluate({&x}, {theta}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, thresholdedrelu_test2) { + + const float theta = -2.f; + auto x = NDArrayFactory::create('c', {2, 3, 4}, {0.f,-4.f, -10.f, -8.f, 0.f, -9.f, -8.f, 5.f, 6.f, 6.f, 9.f, 6.f, -8.f, 5.f, 10.f, -2.f, 3.f, -7.f, 4.f, -8.f, -4.f, -9.f, -9.f, 3.f}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, 6.f, 9.f, 6.f, 0.f, 
5.f, 10.f, 0.f, 3.f, 0.f, 4.f, 0.f, 0.f, 0.f, 0.f, 3.f}); + + sd::ops::thresholdedrelu op; + + auto result = op.evaluate({&x}, {theta}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto output = result.at(0); + + ASSERT_TRUE(exp.isSameShape(output)); + ASSERT_TRUE(exp.equalsTo(output)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_bp_test1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); + auto alpha = NDArrayFactory::create('c', {3, 4}, {-0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); + + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {}); + + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_bp_test2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {-12., -11., -10., -9., -8., -7., -6., -5., -4., -3., -2., -1., 0.5, 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.}); + auto alpha = NDArrayFactory::create('c', {4}, {-0.6, 2., 4., -1.}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); + + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {1}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {1}); + + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_bp_test3) { + + auto x = NDArrayFactory::create('c', {2, 3, 2, 5}); + 
x.linspace(-30.); + x.p(30, 0.5); // avoid zero, since it is points of discontinuity for prelu + auto alpha = NDArrayFactory::create('c', {5,3}, {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.5, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 2, 5}); + + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {-1, 2}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {-1, 2}); + + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, prelu_bp_test4) { + + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + x.linspace(-50.); + x.p(50, 0.5); // avoid zero, since it is points of discontinuity for prele + auto alpha = NDArrayFactory::create('c', {2,10}, {-0.7, -0.6, -0.5, -0.4, -0.3, -0.2, -0.1, 0.25, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1., 1.1, 1.2}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4, 5}); + + const OpArgsHolder argsHolderFF({&x, &alpha}, {}, {-2}); + const OpArgsHolder argsHolderBP({&x, &alpha, &dLdO}, {}, {-2}); + + sd::ops::prelu opFF; + sd::ops::prelu_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, thresholdedrelu_bp_test1) { + + const double theta = 0.15; + + auto x = NDArrayFactory::create('c', {2, 3, 4}, {1.2, 1.1, 1., 0.9, 0.8, -0.7, -0.6,-0.5,-0.4,-0.3,-0.2,-0.1, 0., 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, -0.9, -1.0, -1.1}); + auto dLdO = NDArrayFactory::create('c', {2, 3, 4}); + + const OpArgsHolder argsHolderFF({&x}, {theta}, {}); + const OpArgsHolder argsHolderBP({&x, &dLdO}, {theta}, {}); + + sd::ops::thresholdedrelu opFF; + sd::ops::thresholdedrelu_bp opBP; 
+ + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_test1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {4}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.4f, 0.9f, 1.6f, 0.5f, 1.2f, 2.1f, 3.2f, 0.9f, 2.f, 3.3f, 4.8f, 1.3f, 2.8f, 4.5f, 6.4f, 1.7f, 3.6f, 5.7f, 8.f, 2.1f, 4.4f, 6.9f, 9.6f}); + x.linspace(1.f); + y.linspace(0.1f, 0.1f); + + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_test2) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create(0.1); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f, 1.7f, 1.8f, 1.9f, 2.f, 2.1f, 2.2f, 2.3f, 2.4f}); + x.linspace(1.f); + // y.linspace(0.1f, 0.1f); + + sd::ops::multiply op; + auto result = op.evaluate({&y, &x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_test3) { + + auto x = NDArrayFactory::create('c', {2, 1, 4}); + auto y = NDArrayFactory::create('c', {3,1}); + auto exp = NDArrayFactory::create('c', {2, 3, 4}, {0.1f, 0.2f, 0.3f, 0.4f, 0.2f, 0.4f, 0.6f, 0.8f, 0.3f, 0.6f, 0.9f, 1.2f, 0.5f, 0.6f, 0.7f, 0.8f, 1.f, 1.2f, 1.4f, 1.6f, 1.5f, 1.8f, 2.1f, 2.4f}); + x.linspace(1.f); + y.linspace(0.1f, 0.1f); + + 
sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_test4) { + + auto x = NDArrayFactory::create('c', {1, 1}); + auto y = NDArrayFactory::create(0.1f); + auto exp = NDArrayFactory::create('c', {1, 1}, {0.1f}); + x.linspace(1.f); + + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_test5) { + + auto x = NDArrayFactory::create(1.f); + auto y = NDArrayFactory::create(0.1f); + auto exp = NDArrayFactory::create(0.1f); + + sd::ops::multiply op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_bp_test1) { + + auto x = NDArrayFactory::create('c', {1, 1}, {100.}); + auto y = NDArrayFactory::create(0.1); + auto dLdz = NDArrayFactory::create('c', {1, 1}); + + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; + auto resFF = opFF.evaluate({&x, &y}, {}, {}); + auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {}); +// resFF->at(0)->printIndexedBuffer("Multiply 1x1"); +// resBP->at(0)->printIndexedBuffer("Multiply BP 1x1 x"); +// resBP->at(1)->printIndexedBuffer("Multyply BP 1x1 y");*/ + const bool isGradCorrect = GradCheck::checkGrad(opFF, 
opBP, argsHolderFF, argsHolderBP); + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_bp_test2) { + + auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); + auto y = NDArrayFactory::create(0.1); + auto dLdz = NDArrayFactory::create('c', {2, 2}); + + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_bp_test3) { + + auto y = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); + auto x = NDArrayFactory::create(0.1); + auto dLdz = NDArrayFactory::create('c', {2, 2}); + + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_bp_test4) { + + auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); + auto y = NDArrayFactory::create('c', {2, 2}, {0.1,0.2,0.3,0.4}); + auto dLdz = NDArrayFactory::create('c', {2, 2}); + + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_bp_test5) { 
+ + auto x = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); + auto y = NDArrayFactory::create('c', {2}, {0.1,0.2}); + auto dLdz = NDArrayFactory::create('c', {2, 2}); + + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_bp_test6) { + + auto y = NDArrayFactory::create('c', {2, 2}, {1.,2.,3.,4.}); + auto x = NDArrayFactory::create('c', {2}, {0.1,0.2}); + auto dLdz = NDArrayFactory::create('c', {2, 2}); + + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_bp_test7) { + + auto y = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); + auto x = NDArrayFactory::create('c', {2, 1}, {0.1,0.2}); + auto dLdz = NDArrayFactory::create('c', {2, 3}); + + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, multiply_bp_test8) { + + auto y = NDArrayFactory::create('c', {2, 1, 4}); + auto x = NDArrayFactory::create('c', {1, 3, 4}); + auto dLdz = NDArrayFactory::create('c', {2, 3, 4}); + 
x.linspace(1., 0.5); + y.linspace(0.1, 0.05); + + const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + + sd::ops::multiply opFF; + sd::ops::multiply_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, Floormod_BP_Test_2) { + + auto y = NDArrayFactory::create('c', {10, 10}); + auto x = NDArrayFactory::create('c', {10, 10}); + auto dLdz = NDArrayFactory::create('c', {10, 10}); + //auto eps = NDArrayFactory::create('c', {10, 10}); + x.linspace(4); //2., 2.0); + y.linspace(3); + dLdz.linspace(1); +// const OpArgsHolder argsHolderFF({&x, &y}, {}, {}); +// const OpArgsHolder argsHolderBP({&x, &y, &dLdz}, {}, {}); + +// sd::ops::floormod opFF; +// auto resFF = opFF.execute({&x, &y}, {}, {}); +// resFF->at(0)->printIndexedBuffer("FF floormod"); +// delete resFF; + sd::ops::floormod_bp opBP; + auto resBP = opBP.evaluate({&x, &y, &dLdz}, {}, {}); + ASSERT_TRUE(resBP.status() == ND4J_STATUS_OK); + +// resBP->at(0)->printIndexedBuffer("BP floormod /dx"); +// resBP->at(1)->printIndexedBuffer("BP floormod /dy"); + ASSERT_TRUE(dLdz.equalsTo(resBP.at(0))); + ASSERT_TRUE(dLdz.equalsTo(resBP.at(1))); + +// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + +// ASSERT_TRUE(isGradCorrect); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_1) { + + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 1, 0, 2}); + auto dLdzX = NDArrayFactory::create('c', {2, 4}); + auto dLdzY = NDArrayFactory::create('c', {2, 4}); + auto dLdzZ = NDArrayFactory::create('c', {2, 4}); + auto exp = NDArrayFactory::create('c', {2,3,4}, {1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 2, 2, 
2, 2, 1, 1, 1, 1, 3, 3, 3, 3}); + x.linspace(1); +// dLdzX.linspace(1); +// dLdzY.linspace(2); +// dLdzZ.linspace(3); + dLdzX.assign(1); + dLdzY.assign(2); + dLdzZ.assign(3); + + sd::ops::dynamic_partition op1; + auto res1 = op1.evaluate({&x, &y}, {}, {3}); + + sd::ops::dynamic_partition_bp op2; + auto res2 = op2.evaluate({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3}); + ASSERT_TRUE(res2.status() == ND4J_STATUS_OK); + ASSERT_TRUE(res2.size() == 2); +// printf("How many: %ul\n", res2->size()); +// res2->at(0)->printBuffer("Ouputput0"); +// res2->at(1)->printBuffer("Ouputput1"); + ASSERT_TRUE(res2.at(0)->equalsTo(exp)); + +} +////////////////////////////////////////////////////////////////////// +//TEST_F(DeclarableOpsTests9, Dynamic_Partition_BP_2) { +// +// auto x = NDArrayFactory::create('c', {2, 3, 4}); +// auto y = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 1, 0, 2}); +// auto dLdzX = NDArrayFactory::create('c', {2, 4}); +// auto dLdzY = NDArrayFactory::create('c', {2, 4}); +// auto dLdzZ = NDArrayFactory::create('c', {2, 4}); +// x.linspace(1); +// dLdzX.linspace(1); +// dLdzY.linspace(1); +// dLdzZ.linspace(1); +// +// const OpArgsHolder argsHolderFF({&x, &y}, {}, {3}); +// const OpArgsHolder argsHolderBP({&x, &y, &dLdzX, &dLdzY, &dLdzZ}, {}, {3}); +// +// sd::ops::dynamic_partition opFF; +// sd::ops::dynamic_partition_bp opBP; +// +// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); +// +// ASSERT_TRUE(isGradCorrect); +//} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, Floormod_BP_Test_4) { + + auto x = NDArrayFactory::create('c', {2, 1, 3}, {2.0, 6.0, -3.0, 2.0, 6.0, -3.0}); + auto y = NDArrayFactory::create('c', {1, 3}, {-3.0, 2.0, -2.0}); + auto exp = NDArrayFactory::create('c', {1, 3}, {-1., 0., -1.}); + auto eps = NDArrayFactory::create('c', {2, 1, 3}); + eps.assign(1.f); + sd::ops::floormod_bp op; + + auto result = op.evaluate({&x, &y, &eps}, {}, 
{}); + + ASSERT_TRUE(result.size() == 2); + auto gradX = result.at(0); + auto gradY = result.at(1); + +// gradX->printIndexedBuffer("gradX"); +// gradY->printIndexedBuffer("gradY"); + ASSERT_TRUE(exp.isSameShape(gradY)); + + ASSERT_TRUE(exp.equalsTo(gradY)); + +} + + +/* +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, gru_cell_bp_test1) { + + const int bS = 2; + const int iS = 3; + const int nU = 4; + + NDArray x('c', {bS, iS}, sd::DataType::DOUBLE); + NDArray hi('c', {bS, nU}, sd::DataType::DOUBLE); + NDArray W('c', {iS+nU, 2*nU}, sd::DataType::DOUBLE); + NDArray Wc('c', {iS+nU, nU}, sd::DataType::DOUBLE); + NDArray b('c', {2*nU}, sd::DataType::DOUBLE); + NDArray bc('c', {nU}, sd::DataType::DOUBLE); + NDArray dLdr('c', {bS, nU}, sd::DataType::DOUBLE); + NDArray dLdu('c', {bS, nU}, sd::DataType::DOUBLE); + NDArray dLdc('c', {bS, nU}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nU}, sd::DataType::DOUBLE); + + x.linspace(-5, 0.5); + hi = 1.; + W = 0.003; + Wc = 0.006; + b = 0.5; + bc = 0.35; + + + const OpArgsHolder argsHolderFF({&x, &hi, &W, &Wc, &b, &bc}, {}, {}); + sd::ops::gruCell op; + auto results = op.evaluate(argsHolderFF); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + auto u = results.at(1); // [bS, nU] + auto c = results.at(2); // [bS, nU] + auto h = results.at(3); // [bS, nU] + + dLdh = 1.; // SUM loss + + NDArray Wch = Wc({iS,iS+nU, 0,0}); // [nU, nU] + NDArray dhdc = 1. - *u; + NDArray dhdu = hi - *c; + NDArray dcdZc = 1. 
- *c * *c; + dLdc.assign(dLdh * dhdc); + dLdu.assign(dLdh * dhdu); + dLdr.assign(mmul(dLdc * dcdZc * hi, Wch.transpose())); + + + + + const OpArgsHolder argsHolderBP({&x, &hi, &W, &Wc, &b, &bc, &dLdr, &dLdu, &dLdc, &dLdh}, {}, {}); + + sd::ops::gruCell opFF; + sd::ops::gruCell_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1, 1 , 1, 1}, {0., 1.}, sd::GradCheck::LossFunc::SUM, true); + + ASSERT_TRUE(isGradCorrect); +} +*/ + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, Cholesky_Test_1) { + + NDArray x = NDArrayFactory::create('c', {3, 3}, {4,12,-16, 12 ,37,-43, -16, -43, 98}); + NDArray exp = NDArrayFactory::create('c', {3,3}, {2., 0., 0., 6., 1., 0., -8., 5., 3.}); + + sd::ops::cholesky op; + + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); +// res->printIndexedBuffer("Output for Cholesky1"); + ASSERT_TRUE(exp.equalsTo(res)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, Cholesky_Test_2) { + + NDArray x = NDArrayFactory::create('c', {2, 3, 3}, {4, 12,-16, 12 ,37,-43, -16, -43, 98, 1, 1, 1, 1, 2, 2, 1, 2., 6}); + NDArray exp = NDArrayFactory::create('c', {2, 3, 3}, {2., 0., 0., 6., 1., 0., -8., 5., 3., 1., 0., 0., 1., 1., 0,1., 1., 2.}); + + sd::ops::cholesky op; + + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); +// res->printIndexedBuffer("Output for Cholesky 2"); + ASSERT_TRUE(exp.equalsTo(res)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests9, Cholesky_Test_3) { + + NDArray x = NDArrayFactory::create('c', {2, 3, 3}, {4.f, 12.f, -16.f, 12.f, 37.f, -43.f, -16.f, -43.f, 98.f, 1.f, 1.f, 1.f, 1.f, 2.f, 2.f, 1.f, 2.f, 6.f}); + NDArray exp = NDArrayFactory::create('c', {2, 3, 3}, {2.f, 0.f, 0.f, 6.f, 1.f, 0.f, 
-8.f, 5.f, 3.f, 1.f, 0.f, 0.f, 1.f, 1.f, 0.f, 1.f, 1.f, 2.f}); + + sd::ops::cholesky op; + + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(result.status(), ND4J_STATUS_OK); + auto res = result.at(0); + // res->printIndexedBuffer("Output for Cholesky 3"); + ASSERT_TRUE(exp.equalsTo(res, 1e-4)); + +} + +//////////////////////////////////////////////////////////////////// +// TEST_F(DeclarableOpsTests9, gru_bp_test1) { + +// const int time = 5; +// const int bS = 2; +// const int iS = 3; +// const int nU = 4; + +// NDArray x ('c', {time, bS, iS}); +// NDArray h0 ('c', {bS, nU}); +// NDArray Wx ('c', {iS, 3*nU}); +// NDArray Wh ('c', {nU, 3*nU}); +// NDArray b ('c', {3*nU}); +// NDArray dLdh ('c', {time, bS, nU}); + +// x.linspace(0.5, 0.5); +// h0 = 1.; +// Wx = 0.003; +// Wh = 0.006; +// b = 0.5; + +// const OpArgsHolder argsHolderFF({&x, &h0, &Wx, &Wh, &b}, {}, {}); +// const OpArgsHolder argsHolderBP({&x, &h0, &Wx, &Wh, &b, &dLdh}, {}, {}); + +// sd::ops::gru opFF; +// sd::ops::gru_bp opBP; + +// const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + +// ASSERT_TRUE(isGradCorrect); +// } + +// diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu new file mode 100644 index 000000000..db8da9e61 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/DeclarableOpsTestsCuda1.cu @@ -0,0 +1,78 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace sd; + + +class DeclarableOpsTestsCuda1 : public testing::Test { +public: + + DeclarableOpsTestsCuda1() { + printf("\n"); + fflush(stdout); + } +}; + + +TEST_F(DeclarableOpsTestsCuda1, Test_CHOOSE_SCALAR_LARGE) { + double inputData[150] = { + 0, 0.51, 0.68, 0.69, 0.86, 0.91, 0.96, 0.97, 0.97, 1.03, 1.13, 1.16, 1.16, 1.17, 1.19, 1.25, 1.25, 1.26, 1.27, 1.28, 1.29, 1.29, 1.29, 1.30, 1.31, 1.32, 1.33, 1.33, 1.35, 1.35, 1.36, 1.37, 1.38, 1.40, 1.41, 1.42, 1.43, 1.44, 1.44, 1.45, 1.45, 1.47, 1.47, 1.51, 1.51, 1.51, 1.52, 1.53, 1.56, 1.57, 1.58, 1.59, 1.61, 1.62, 1.63, 1.63, 1.64, 1.64, 1.66, 1.66, 1.67, 1.67, 1.70, 1.70, 1.70, 1.72, 1.72, 1.72, 1.72, 1.73, 1.74, 1.74, 1.76, 1.76, 1.77, 1.77, 1.80, 1.80, 1.81, 1.82, 1.83, 1.83, 1.84, 1.84, 1.84, 1.85, 1.85, 1.85, 1.86, 1.86, 1.87, 1.88, 1.89, 1.89, 1.89, 1.89, 1.89, 1.91, 1.91, 1.91, 1.92, 1.94, 1.95, 1.97, 1.98, 1.98, 1.98, 1.98, 1.98, 1.99, 2, 2, 2.01, 2.01, 2.02, 2.03, 2.03, 2.03, 2.04, 2.04, 2.05, 2.06, 2.07, 2.08, 2.08, 2.08, 2.08, 2.09, 2.09, 2.10, 2.10, 2.11, 2.11, 2.11, 2.12, 2.12, 2.13, 2.13, 2.14, 2.14, 2.14, 2.14, 2.15, 2.15, 2.16, 2.16, 2.16, 2.16, 2.16, 2.17 + }; + + auto precursor = NDArrayFactory::create(inputData,'c',{1,149}); + NDArray x(nullptr, precursor.specialBuffer(), precursor.shapeInfo()); + + sd::ops::choose op; + //greater than test + auto result = op.evaluate({&x}, {0.0},{3}); + ASSERT_EQ(Status::OK(), 
result.status()); + + auto z = result.at(1); + + ASSERT_EQ(148,z->e(0)); + //ASSERT_TRUE(exp.isSameShape(z)); +} + +/* +TEST_F(DeclarableOpsTestsCuda1, Test_Reverse_TAD_1) { + auto x = NDArrayFactory::create('c', {1, 3, 608, 608}); + auto z = x.like(); + x.linspace(1.0f); + + sd::ops::reverse op; + auto timeStart = std::chrono::system_clock::now(); + auto status = op.execute({&x}, {&z}, {}, {1}, {}); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + nd4j_printf("exec time: %lld us\n", outerTime); + ASSERT_EQ(Status::OK(), status); +} +*/ \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/EmptyTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/EmptyTests.cpp new file mode 100644 index 000000000..28c060757 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/EmptyTests.cpp @@ -0,0 +1,256 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver on 6/18/2018. 
+// + +#include "testlayers.h" +#include +#include +// #include + +using namespace sd; + + +class EmptyTests : public testing::Test { +public: + + EmptyTests() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(EmptyTests, Test_Create_Empty_1) { + auto empty = NDArrayFactory::empty_(); + ASSERT_TRUE(empty->isEmpty()); + + ASSERT_EQ(0, empty->lengthOf()); + ASSERT_TRUE(empty->buffer() == nullptr); + + ASSERT_TRUE(shape::isEmpty(empty->shapeInfo())); + + delete empty; +} + +TEST_F(EmptyTests, Test_Create_Empty_2) { + auto empty = NDArrayFactory::empty(); + ASSERT_TRUE(empty.isEmpty()); + + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.buffer() == nullptr); + + ASSERT_TRUE(shape::isEmpty(empty.shapeInfo())); + ASSERT_TRUE(empty.isEmpty()); +} + +TEST_F(EmptyTests, Test_Concat_1) { +// auto empty = NDArrayFactory::empty_(); + auto empty = new NDArray('c', {0}, sd::DataType::FLOAT32);//NDArrayFactory::create_('c', {(Nd4jLong)0}}; + auto vector = NDArrayFactory::create_('c', {1}, {1.0f}); + + ASSERT_TRUE(empty->isEmpty()); + + sd::ops::concat op; + auto result = op.evaluate({empty, vector}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + +// z->printShapeInfo("z shape"); +// z->printIndexedBuffer("z buffr"); + + ASSERT_EQ(*vector, *z); + + delete empty; + delete vector; +} + + +TEST_F(EmptyTests, Test_Concat_2) { + auto empty = new NDArray('c', {0}, sd::DataType::FLOAT32); //NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create_('c', {1}, {1.0f}); + auto scalar2 = NDArrayFactory::create_('c', {1}, {2.0f}); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + + ASSERT_TRUE(empty->isEmpty()); + + sd::ops::concat op; + auto result = op.evaluate({empty, scalar1, scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + +// z->printShapeInfo("z shape"); +// z->printIndexedBuffer("z buffr"); + + ASSERT_EQ(exp, *z); + + delete empty; + delete scalar1; + delete scalar2; +} + 
+TEST_F(EmptyTests, Test_Concat_3) { + auto empty = NDArrayFactory::empty(); //NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create(1.0f); + auto scalar2 = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + + ASSERT_TRUE(empty.isEmpty()); + + sd::ops::concat op; + auto result = op.evaluate({&empty, &scalar1, &scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(exp, *z); + +} + +TEST_F(EmptyTests, Test_Concat_4) { + auto empty = NDArrayFactory::empty(); //NDArrayFactory::empty_(); + auto scalar1 = NDArrayFactory::create(1.0f); + auto scalar2 = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + + ASSERT_TRUE(empty.isEmpty()); + + sd::ops::concat op; + auto result = op.evaluate({&scalar1, &empty, &scalar2}, {}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(exp, *z); +} + +TEST_F(EmptyTests, Test_dup_1) { + auto empty = NDArrayFactory::empty(); + auto dup = new NDArray(empty.dup()); + + ASSERT_TRUE(dup->isEmpty()); + ASSERT_EQ(empty, *dup); + + delete dup; +} + +TEST_F(EmptyTests, test_empty_scatter_1) { + auto x = NDArrayFactory::create('c', {5}); + auto indices = NDArrayFactory::create('c', {0}); + auto updates = NDArrayFactory::create('c', {0}); + + x.linspace(1.0f); + + sd::ops::scatter_upd op; + auto result = op.evaluate({&x, &indices, &updates}, {}, {}, {true}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(x, *z); + +} + +TEST_F(EmptyTests, test_empty_scatter_2) { + NDArray x ('c', {5}, sd::DataType::FLOAT32); + NDArray z ('c', {5}, sd::DataType::FLOAT32); + auto indices = NDArrayFactory::create('c', {0}); + auto updates = NDArrayFactory::create('c', {0}); + + x.linspace(1.0f); + + sd::ops::scatter_upd op; + auto status = op.execute({&x, &indices, &updates}, {&z}, {}, {}, {true}); + + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(x, z); 
+} + +TEST_F(EmptyTests, test_shaped_empty_1) { + auto empty = NDArrayFactory::create('c', {2, 0, 3}); + std::vector shape = {2, 0, 3}; + + ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(3, empty.rankOf()); +} + +TEST_F(EmptyTests, test_shaped_empty_2) { + auto empty = NDArrayFactory::create('c', {0, 3}); + std::vector shape = {0, 3}; + + ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(2, empty.rankOf()); +} + +TEST_F(EmptyTests, test_shaped_empty_3) { + auto empty = NDArrayFactory::create('c', {0}); + std::vector shape = {0}; + + ASSERT_EQ(sd::DataType::FLOAT32, empty.dataType()); + ASSERT_EQ(0, empty.lengthOf()); + ASSERT_TRUE(empty.isEmpty()); + ASSERT_EQ(shape, empty.getShapeAsVector()); + ASSERT_EQ(1, empty.rankOf()); +} + +TEST_F(EmptyTests, test_shaped_empty_4) { + const auto shape = ConstantShapeHelper::getInstance().vectorShapeInfo(0, sd::DataType::FLOAT32); + NDArray array(shape, true, sd::LaunchContext::defaultContext()); + std::vector shapeOf({0}); + + ASSERT_TRUE(array.isEmpty()); + ASSERT_EQ(1, array.rankOf()); + ASSERT_EQ(shapeOf, array.getShapeAsVector()); +} + + +TEST_F(EmptyTests, test_empty_matmul_1) { + auto x = NDArrayFactory::create('c', {0, 1}); + auto y = NDArrayFactory::create('c', {1, 0}); + auto e = NDArrayFactory::create('c', {0, 0}); + + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(e, *z); + +} + +TEST_F(EmptyTests, test_empty_matmul_2) { + auto x = NDArrayFactory::create('c', {1, 0, 4}); + auto y = NDArrayFactory::create('c', {1, 4, 0}); + auto e = NDArrayFactory::create('c', {1, 0, 0}); + + sd::ops::matmul op; + auto result = op.evaluate({&x, &y}, {}, {}); + 
ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_EQ(e, *z); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ExtraArgumentsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ExtraArgumentsTests.cpp new file mode 100644 index 000000000..fdcecf4a4 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ExtraArgumentsTests.cpp @@ -0,0 +1,68 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include + +using namespace sd; + +class ExtraArgumentsTests : public testing::Test { +public: + + ExtraArgumentsTests() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(ExtraArgumentsTests, Basic_Test_1) { + if (!Environment::getInstance().isCPU()) + return; + + ExtraArguments args({1.0, 2.0, 3.0}); + + float ef[] = {1.f, 2.f, 3.f}; + double ed[] = {1., 2., 3.}; + + auto ptrFloat = reinterpret_cast(args.argumentsAsT()); + auto ptrDouble = reinterpret_cast(args.argumentsAsT()); + ASSERT_TRUE(ptrFloat != nullptr); + ASSERT_TRUE(ptrDouble != nullptr); + + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ef[e], ptrFloat[e], 1e-5f); + } + + for (int e = 0; e < 3; e++) { + ASSERT_NEAR(ed[e], ptrDouble[e], 1e-5); + } +} + + +TEST_F(ExtraArgumentsTests, Basic_Test_2) { + ExtraArguments args; + + auto ptrInt = args.argumentsAsT(); + ASSERT_TRUE(ptrInt == nullptr); +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/FlatBuffersTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/FlatBuffersTests.cpp new file mode 100644 index 000000000..816ec3f92 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/FlatBuffersTests.cpp @@ -0,0 +1,817 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class FlatBuffersTest : public testing::Test { +public: + int alpha = 0; + + Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102}; + + FlatBuffersTest() { + Environment::getInstance().setDebug(false); + Environment::getInstance().setVerbose(false); + Environment::getInstance().setProfiling(false); + } + + ~FlatBuffersTest() { + Environment::getInstance().setDebug(false); + Environment::getInstance().setVerbose(false); + Environment::getInstance().setProfiling(false); + + delete[] cShape; + delete[] fShape; + } +}; + +/** + * Simple test that creates Node & reads it + */ +TEST_F(FlatBuffersTest, BasicTest1) { + flatbuffers::FlatBufferBuilder builder(1024); + + auto name = builder.CreateString("wow"); + + auto node = CreateFlatNode(builder, -1, name, OpType_TRANSFORM_SAME, transform::Ones, {0}); + + builder.Finish(node); + + // now we have our buffer with data + uint8_t *buf = builder.GetBufferPointer(); + int size = builder.GetSize(); + ASSERT_TRUE(size > 0); + + + + auto restored = GetFlatNode(buf); + + auto gA = new Node(restored); + auto gB = new Node(restored); + + ASSERT_TRUE(gA->equals(gB)); + + delete gA; + delete gB; +} + +/* +TEST_F(FlatBuffersTest, FlatGraphTest1) { + flatbuffers::FlatBufferBuilder builder(4096); + + auto array = 
NDArrayFactory::create_('c', {5, 5}); + array->assign(-2.0f); + + auto fShape = builder.CreateVector(array->getShapeInfoAsFlatVector()); + auto fBuffer = builder.CreateVector(array->asByteVector()); + + auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_FLOAT); + auto fVid = CreateIntPair(builder, -1); + + auto fVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_FLOAT, 0, fArray); + + std::vector outputs1, outputs2, inputs1, inputs2; + outputs1.push_back(2); + outputs2.push_back(0); + + inputs1.push_back(-1); + inputs2.push_back(1); + + + auto vec1 = builder.CreateVector(outputs1); + auto vec2 = builder.CreateVector(outputs2); + + auto in1 = builder.CreateVector(inputs1); + auto in2 = builder.CreateVector(inputs2); + + auto name1 = builder.CreateString("wow1"); + auto name2 = builder.CreateString("wow2"); + + auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM_SAME, transform::Abs, 0, in1, 0, vec1); + auto node2 = CreateFlatNode(builder, 2, name2, OpType_TRANSFORM_STRICT, transform::Cosine, 0, in2, 0, vec2); + + std::vector> variables_vector; + variables_vector.push_back(fVar); + + std::vector> nodes_vector; + + nodes_vector.push_back(node1); + nodes_vector.push_back(node2); + + auto nodes = builder.CreateVector(nodes_vector); + + auto variables = builder.CreateVector(variables_vector); + + FlatGraphBuilder graphBuilder(builder); + + graphBuilder.add_variables(variables); + graphBuilder.add_id(119); + graphBuilder.add_nodes(nodes); + + auto flatGraph = graphBuilder.Finish(); + + builder.Finish(flatGraph); + + uint8_t *buf = builder.GetBufferPointer(); + int size = builder.GetSize(); + ASSERT_TRUE(size > 0); + + + auto restoredGraph = GetFlatGraph(buf); + ASSERT_EQ(119, restoredGraph->id()); + ASSERT_EQ(2, restoredGraph->nodes()->size()); + + // checking op nodes + ASSERT_EQ(transform::Abs, restoredGraph->nodes()->Get(0)->opNum()); + ASSERT_EQ(transform::Cosine, restoredGraph->nodes()->Get(1)->opNum()); + 
ASSERT_EQ(transform::Abs, restoredGraph->nodes()->Get(0)->opNum()); + + // checking variables + ASSERT_EQ(1, restoredGraph->variables()->size()); + ASSERT_EQ(-1, restoredGraph->variables()->Get(0)->id()->first()); + + // nd4j_printf("-------------------------\n",""); + + Graph graph(restoredGraph); + + // graph.printOut(); + + ASSERT_EQ(2, graph.totalNodes()); + ASSERT_EQ(1, graph.rootNodes()); + + + auto vs = graph.getVariableSpace(); + + ASSERT_EQ(OutputMode_IMPLICIT, graph.getExecutorConfiguration()->_outputMode); + + ASSERT_EQ(3, vs->totalEntries()); + ASSERT_EQ(1, vs->externalEntries()); + ASSERT_EQ(2, vs->internalEntries()); + + auto var = vs->getVariable(-1)->getNDArray(); + + ASSERT_TRUE(var != nullptr); + ASSERT_EQ(-2.0, var->reduceNumber(reduce::Mean).e(0)); + + sd::graph::GraphExecutioner::execute(&graph); + + auto resultWrapper = sd::graph::GraphExecutioner::executeFlatBuffer((Nd4jPointer) buf); + + auto flatResults = GetFlatResult(resultWrapper->pointer()); + + ASSERT_EQ(1, flatResults->variables()->size()); + ASSERT_TRUE(flatResults->variables()->Get(0)->name() != nullptr); + ASSERT_TRUE(flatResults->variables()->Get(0)->name()->c_str() != nullptr); + //nd4j_printf("VARNAME: %s\n", flatResults->variables()->Get(0)->name()->c_str()); + + auto var0 = new Variable(flatResults->variables()->Get(0)); + //auto var1 = new Variable(flatResults->variables()->Get(1)); + auto avg = var0->getNDArray()->reduceNumber(reduce::Mean); + avg.printIndexedBuffer("FBT_1"); + ASSERT_NEAR(-0.4161468, avg.e(0), 1e-5); + + //ASSERT_TRUE(var->equalsTo(var0->getNDArray())); + + delete array; + delete var0; + delete resultWrapper; +} +*/ +TEST_F(FlatBuffersTest, ExecutionTest1) { + auto gA = new Node(OpType_TRANSFORM_SAME); + + auto c = new float[4] {-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); + + auto e = new float[4] {1, 2, 3, 4}; + auto exp = new NDArray(e, cShape); + + //gA->execute(array, nullptr, array); + + //ASSERT_TRUE(exp->equalsTo(array)); + + delete gA; + 
delete[] c; + delete array; + delete[] e; + delete exp; +} + +/* +TEST_F(FlatBuffersTest, ExplicitOutputTest1) { + flatbuffers::FlatBufferBuilder builder(4096); + + auto x = NDArrayFactory::create_(5, 5, 'c'); + x->assign(-2.0f); + + auto fXShape = builder.CreateVector(x->getShapeInfoAsVector()); + auto fXBuffer = builder.CreateVector(x->asByteVector()); + auto fXArray = CreateFlatArray(builder, fXShape, fXBuffer); + auto fXid = CreateIntPair(builder, -1); + + auto fXVar = CreateFlatVariable(builder, fXid, 0, 0, fXArray); + + + auto y = NDArrayFactory::create_(5, 5, 'c'); + y->assign(-1.0f); + + auto fYShape = builder.CreateVector(y->getShapeInfoAsVector()); + auto fYBuffer = builder.CreateVector(y->asByteVector()); + auto fYArray = CreateFlatArray(builder, fYShape, fYBuffer); + auto fYid = CreateIntPair(builder, -2); + + auto fYVar = CreateFlatVariable(builder, fYid, 0, 0, fYArray); + + + std::vector> inputs1, outputs1, outputs; + inputs1.push_back(CreateIntPair(builder, -1)); + inputs1.push_back(CreateIntPair(builder, -2)); + + outputs.push_back(CreateIntPair(builder, -1)); + outputs.push_back(CreateIntPair(builder, -2)); + + auto out1 = builder.CreateVector(outputs1); + auto in1 = builder.CreateVector(inputs1); + auto o = builder.CreateVector(outputs); + + auto name1 = builder.CreateString("wow1"); + + auto node1 = CreateFlatNode(builder, 1, name1, OpType_TRANSFORM, 0, in1, 0, sd::graph::DType::FLOAT); + + std::vector> variables_vector; + variables_vector.push_back(fXVar); + variables_vector.push_back(fYVar); + + std::vector> nodes_vector; + nodes_vector.push_back(node1); + + + + auto nodes = builder.CreateVector(nodes_vector); + auto variables = builder.CreateVector(variables_vector); + + FlatGraphBuilder graphBuilder(builder); + + graphBuilder.add_variables(variables); + graphBuilder.add_id(119); + graphBuilder.add_nodes(nodes); + graphBuilder.add_outputs(o); + + + auto flatGraph = graphBuilder.Finish(); + builder.Finish(flatGraph); + + auto restoredGraph = 
new Graph(GetFlatGraph(builder.GetBufferPointer())); + + GraphExecutioner::execute(restoredGraph); + + auto results = restoredGraph->fetchOutputs(); + + // IMPLICIT is default + ASSERT_EQ(1, results->size()); + + //ASSERT_NEAR(-2.0, results->at(0)->getNDArray()->reduceNumber>(), 1e-5); + //ASSERT_NEAR(-1.0, results->at(1)->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(-3.0, results->at(0)->getNDArray()->reduceNumber>(), 1e-5); + + //ASSERT_EQ(-1, results->at(0)->id()); + //ASSERT_EQ(-2, results->at(1)->id()); + + delete restoredGraph; + delete results; + delete x; + delete y; +} +*/ + +/* +TEST_F(FlatBuffersTest, ReadFile1) { + + uint8_t* data = sd::graph::readFlatBuffers("./resources/adam_sum.fb"); + + auto fg = GetFlatGraph(data); + auto restoredGraph = new Graph(fg); + + ASSERT_EQ(1, restoredGraph->rootNodes()); + ASSERT_EQ(2, restoredGraph->totalNodes()); + + auto ones = restoredGraph->getVariableSpace()->getVariable(-1)->getNDArray(); + + ASSERT_EQ(4, ones->lengthOf()); + ASSERT_NEAR(4.0f, ones->template reduceNumber>(), 1e-5); + + Nd4jStatus status = GraphExecutioner::execute(restoredGraph); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto result = restoredGraph->getVariableSpace()->getVariable(2)->getNDArray(); + ASSERT_EQ(1, result->lengthOf()); + ASSERT_EQ(8, result->e(0)); + + delete[] data; + delete restoredGraph; +} + +TEST_F(FlatBuffersTest, ReadFile2) { + uint8_t* data = sd::graph::readFlatBuffers("./resources/adam_sum.fb"); + Nd4jPointer result = GraphExecutioner::executeFlatBuffer((Nd4jPointer) data); + + ResultSet arrays(GetFlatResult(result)); + + ASSERT_EQ(1, arrays.size()); + ASSERT_EQ(1, arrays.at(0)->lengthOf()); + ASSERT_EQ(8, arrays.at(0)->e(0)); + + delete[] data; + delete[] (char *) result; +} + +TEST_F(FlatBuffersTest, ReadFile3) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/adam_sum.fb"); + Nd4jStatus status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto z = 
graph->getVariableSpace()->getVariable(2)->getNDArray(); + + ASSERT_EQ(1, z->lengthOf()); + ASSERT_EQ(8, z->e(0)); + + delete graph; +} + + +TEST_F(FlatBuffersTest, ReadInception1) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/inception.fb"); + + Nd4jStatus status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(227)); + + auto lastNode = graph->getVariableSpace()->getVariable(227)->getNDArray(); + + lastNode->printShapeInfo("Result shape"); + + auto argMax = lastNode->argMax(); + + //nd4j_printf("Predicted class: %i\n", (int) argMax); + //nd4j_printf("Probability: %f\n", lastNode->e(argMax)); + //nd4j_printf("Probability ipod: %f\n", lastNode->e(980)); + //lastNode->printBuffer("Whole output"); + + ASSERT_EQ(561, (int) argMax); + + delete graph; +} + +TEST_F(FlatBuffersTest, ReadLoops_3argsWhile_1) { + // TF graph: + // https://gist.github.com/raver119/b86ef727e9a094aab386e2b35e878966 + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/three_args_while.fb"); + + ASSERT_TRUE(graph != nullptr); + + //graph->printOut(); + + auto expPhi('c', {2, 2}); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(-1)); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(-2)); + + auto phi = graph->getVariableSpace()->getVariable(-2)->getNDArray(); + auto constA = graph->getVariableSpace()->getVariable(-5)->getNDArray(); + auto lessY = graph->getVariableSpace()->getVariable(-6)->getNDArray(); + + //constA->printBuffer("constA"); + //lessY->printBuffer("lessY"); + + ASSERT_TRUE(expPhi.isSameShape(phi)); + + Nd4jStatus status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + // now, we expect some values + + auto x = graph->getVariableSpace()->getVariable(20); + auto y = graph->getVariableSpace()->getVariable(21); + + ASSERT_NEAR(110.0f, x->getNDArray()->meanNumber(), 1e-5); + ASSERT_NEAR(33.0f, y->getNDArray()->meanNumber(), 1e-5); 
+ + delete graph; +} + + + +TEST_F(FlatBuffersTest, ReadTensorArrayLoop_1) { + auto exp('c', {5, 2}, {3., 6., 9., 12., 15., 18., 21., 24., 27., 30.}); + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_loop.fb"); + + ASSERT_TRUE(graph != nullptr); + + //graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto variableSpace = graph->getVariableSpace(); + + ASSERT_TRUE(variableSpace->hasVariable(23,0)); + + auto z = variableSpace->getVariable(23)->getNDArray(); + + //z->printShapeInfo("z shape"); + //z->printIndexedBuffer("z buffer"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + +*/ + +/* +TEST_F(FlatBuffersTest, ReadLoops_NestedWhile_1) { + // TF graph: + // https://gist.github.com/raver119/2aa49daf7ec09ed4ddddbc6262f213a0 + sd::ops::assign op1; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/nested_while.fb"); + + ASSERT_TRUE(graph != nullptr); + + Nd4jStatus status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto x = graph->getVariableSpace()->getVariable(28); + auto y = graph->getVariableSpace()->getVariable(29); + auto z = graph->getVariableSpace()->getVariable(11, 2); + + ASSERT_NEAR(110.0f, x->getNDArray()->meanNumber(), 1e-5); + ASSERT_NEAR(33.0f, y->getNDArray()->meanNumber(), 1e-5); + + // we should have only 3 cycles in nested loop + ASSERT_NEAR(30.0f, z->getNDArray()->meanNumber(), 1e-5); + + delete graph; +} +*/ +/* + +TEST_F(FlatBuffersTest, ReadTensorArray_1) { + // TF graph: https://gist.github.com/raver119/3265923eed48feecc465d17ec842b6e2 + + auto exp('c', {3, 2}, {1.000000, 1.000000, 2.000000, 2.000000, 3.000000, 3.000000}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array.fb"); + + ASSERT_TRUE(graph != nullptr); + + Nd4jStatus status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + 
ASSERT_TRUE(graph->getVariableSpace()->hasVariable(14)); + + auto z = graph->getVariableSpace()->getVariable(14)->getNDArray(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + +*/ +/* +TEST_F(FlatBuffersTest, ReadStridedSlice_1) { + // TF graph: https://gist.github.com/raver119/fc3bf2d31c91e465c635b24020fd798d + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); + + ASSERT_TRUE(graph != nullptr); + + Nd4jStatus status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(7)); + + auto z = graph->getVariableSpace()->getVariable(7)->getNDArray(); + + ASSERT_NEAR(73.5f, z->e(0), 1e-5); + + delete graph; +} + + + +TEST_F(FlatBuffersTest, ReduceDim_1) { + auto exp = NDArrayFactory::create('c', {3}); + exp.assign(3.0); + + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + + graph->printOut(); + + auto variableSpace = graph->getVariableSpace(); + + + ASSERT_TRUE(variableSpace->hasVariable(1)); + ASSERT_TRUE(variableSpace->hasVariable(2)); + + auto x = variableSpace->getVariable(1)->getNDArray(); + auto y = variableSpace->getVariable(2)->getNDArray(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(variableSpace->hasVariable(3)); + + auto result = variableSpace->getVariable(3)->getNDArray(); + + result->printShapeInfo("z"); + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete graph; +} + +TEST_F(FlatBuffersTest, ReduceDim_2) { + auto exp = NDArrayFactory::create('c', {3, 1}); + exp.assign(3.0); + + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_true.fb"); + + graph->printOut(); + + auto variableSpace = graph->getVariableSpace(); + + + ASSERT_TRUE(variableSpace->hasVariable(1)); + ASSERT_TRUE(variableSpace->hasVariable(2)); + + auto x = 
variableSpace->getVariable(1)->getNDArray(); + auto y = variableSpace->getVariable(2)->getNDArray(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(variableSpace->hasVariable(3)); + + auto result = variableSpace->getVariable(3)->getNDArray(); + + ASSERT_TRUE(exp.isSameShape(result)); + ASSERT_TRUE(exp.equalsTo(result)); + + delete graph; +} + */ + +#ifdef GRAPH_FILES_OK +TEST_F(FlatBuffersTest, Ae_00) { + sd::ops::rank op1; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + + auto exp = NDArrayFactory::create('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); + +// graph->printOut(); + + ASSERT_EQ(OutputMode_VARIABLE_SPACE, graph->getExecutorConfiguration()->_outputMode); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(18)); + + auto z = graph->getVariableSpace()->getVariable(18)->getNDArray(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + +TEST_F(FlatBuffersTest, expand_dims) { + sd::ops::rank op1; + + auto exp = NDArrayFactory::create('c', {3, 1, 4}, {-0.95938617f, -1.20301781f, 1.22260064f, 0.50172403f, 0.59972949f, 0.78568028f, 0.31609724f, 1.51674747f, 0.68013491f, -0.05227458f, 0.25903158f, 1.13243439f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/expand_dim.fb"); + +// graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(5)); + + auto z = graph->getVariableSpace()->getVariable(5)->getNDArray(); + + ASSERT_TRUE(exp.isSameShape(z)); + 
ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + +TEST_F(FlatBuffersTest, transpose) { + sd::ops::rank op1; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/transpose.fb"); + + //graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + + delete graph; +} + +TEST_F(FlatBuffersTest, Test_Stitches) { + sd::ops::realdiv op0; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/partition_stitch_misc.fb"); + //graph->printOut(); + + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + + delete graph; +} + +TEST_F(FlatBuffersTest, Test_GruDynamicMnist) { + sd::Environment::getInstance().setDebug(false); + sd::Environment::getInstance().setVerbose(false); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/gru_dynamic_mnist.fb"); + //graph->printOut(); + + auto timeStart = std::chrono::system_clock::now(); + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + + auto timeEnd = std::chrono::system_clock::now(); + + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + + // nd4j_printf("GRU time 1 time %lld us\n", outerTime); + + delete graph; +} + +TEST_F(FlatBuffersTest, Test_Non2D_2) { + sd::Environment::getInstance().setDebug(false); + sd::Environment::getInstance().setVerbose(false); + sd::ops::realdiv op0; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_2.fb"); + //graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + + delete graph; +} + + +TEST_F(FlatBuffersTest, Test_TensorDotMisc) { + Environment::getInstance().setVerbose(false); + Environment::getInstance().setDebug(false); + + auto e = NDArrayFactory::create('c', {1, 3, 16, 20}, {4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 2.f, 3.f, 5.f, 5.f, 1.f, 4.f, 6.f, 3.f, 2.f, 1.f, 5.f, 4.f, 4.f, 4.f, 4.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 5.f, 3.f, 6.f, 
5.f, 4.f, 4.f, 3.f, 6.f, 1.f, 2.f, 4.f, 2.f, 6.f, 4.f, 2.f, 3.f, 2.f, 3.f, 1.f, 2.f, 4.f, 3.f, 5.f, 3.f, 3.f, 5.f, 2.f, 6.f, 3.f, 4.f, 4.f, 4.f, 4.f, 6.f, 4.f, 5.f, 2.f, 5.f, 5.f, 5.f, 5.f, 2.f, 4.f, 4.f, 4.f, 5.f, 4.f, 3.f, 6.f, 3.f, 4.f, 5.f, 2.f, 5.f, 4.f, 4.f, 5.f, 4.f, 3.f, 4.f, 5.f, 5.f, 3.f, 5.f, 6.f, 6.f, 3.f, 4.f, 5.f, 7.f, 6.f, 5.f, 2.f, 4.f, 5.f, 5.f, 4.f, 5.f, 4.f, 4.f, 6.f, 3.f, 4.f, 5.f, 4.f, 6.f, 2.f, 3.f, 4.f, 3.f, 3.f, 2.f, 2.f, 3.f, 4.f, 7.f, 3.f, 5.f, 4.f, 5.f, 4.f, 4.f, 4.f, 4.f, 6.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 5.f, 2.f, 2.f, 1.f, 6.f, 2.f, 2.f, 3.f, 4.f, 5.f, 5.f, 3.f, 6.f, 6.f, 4.f, 3.f, 3.f, 3.f, 3.f, 3.f, 4.f, 5.f, 4.f, 4.f, 3.f, 5.f, 2.f, 3.f, 4.f, 5.f, 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 5.f, 3.f, 3.f, 4.f, 4.f, 5.f, 4.f, 5.f, 3.f, 3.f, 7.f, 2.f, 3.f, 2.f, 6.f, 6.f, 4.f, 4.f, 3.f, 5.f, 6.f, 2.f, 4.f, 3.f, 3.f, 4.f, 5.f, 3.f, 3.f, 6.f, 5.f, 3.f, 2.f, 5.f, 4.f, 4.f, 3.f, 5.f, 5.f, 6.f, 7.f, 3.f, 4.f, 3.f, 5.f, 6.f, 7.f, 5.f, 6.f, 5.f, 7.f, 4.f, 6.f, 5.f, 5.f, 6.f, 4.f, 2.f, 5.f, 4.f, 3.f, 4.f, 1.f, 5.f, 5.f, 3.f, 2.f, 2.f, 6.f, 5.f, 5.f, 2.f, 5.f, 2.f, 4.f, 4.f, 5.f, 5.f, 4.f, 3.f, 7.f, 4.f, 5.f, 3.f, 3.f, 3.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 4.f, 2.f, 4.f, 5.f, 3.f, 4.f, 5.f, 3.f, 7.f, 2.f, 1.f, 3.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 3.f, 5.f, 3.f, 6.f, 6.f, 5.f, 3.f, 5.f, 3.f, 4.f, 3.f, 5.f, 3.f, 5.f, 6.f, 5.f, 3.f, 4.f, 5.f, 5.f, 3.f, 3.f, 3.f, 4.f, 6.f, 4.f, 3.f, 7.f, 4.f, 4.f, 6.f, 7.f, 5.f, 5.f, 3.f, 1.f, 2.f, 5.f, 5.f, 2.f, 5.f, 7.f, 5.f, 3.f, 1.f, 4.f, 6.f, 5.f, 7.f, 5.f, 6.f, 5.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 4.f, 4.f, 4.f, 4.f, 3.f, 5.f, 2.f, 4.f, 5.f, 2.f, 5.f, 5.f, 4.f, 5.f, 4.f, 5.f, 2.f, 3.f, 5.f, 3.f, 6.f, 3.f, 4.f, 5.f, 3.f, 6.f, 5.f, 5.f, 6.f, 4.f, 6.f, 7.f, 4.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 2.f, 2.f, 5.f, 3.f, 5.f, 3.f, 4.f, 6.f, 3.f, 5.f, 5.f, 3.f, 5.f, 4.f, 4.f, 4.f, 5.f, 2.f, 3.f, 5.f, 4.f, 2.f, 4.f, 5.f, 4.f, 2.f, 3.f, 4.f, 4.f, 5.f, 5.f, 1.f, 4.f, 4.f, 4.f, 3.f, 4.f, 5.f, 5.f, 8.f, 4.f, 4.f, 
4.f, 3.f, 6.f, 2.f, 3.f, 4.f, 4.f, 4.f, 3.f, 2.f, 3.f, 4.f, 8.f, 3.f, 5.f, 5.f, 5.f, 3.f, 3.f, 4.f, 5.f, 7.f, 3.f, 3.f, 3.f, 6.f, 6.f, 5.f, 5.f, 3.f, 4.f, 3.f, 8.f, 3.f, 4.f, 2.f, 3.f, 4.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 6.f, 6.f, 5.f, 6.f, 4.f, 5.f, 4.f, 6.f, 4.f, 5.f, 5.f, 4.f, 7.f, 3.f, 5.f, 5.f, 3.f, 5.f, 5.f, 6.f, 4.f, 5.f, 4.f, 2.f, 7.f, 2.f, 3.f, 1.f, 4.f, 5.f, 5.f, 4.f, 4.f, 5.f, 7.f, 2.f, 3.f, 3.f, 4.f, 4.f, 5.f, 3.f, 3.f, 6.f, 6.f, 3.f, 2.f, 4.f, 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 5.f, 1.f, 2.f, 3.f, 3.f, 4.f, 5.f, 4.f, 5.f, 4.f, 5.f, 6.f, 6.f, 6.f, 6.f, 7.f, 4.f, 3.f, 4.f, 5.f, 4.f, 4.f, 2.f, 5.f, 6.f, 4.f, 2.f, 2.f, 6.f, 5.f, 5.f, 1.f, 4.f, 2.f, 3.f, 4.f, 5.f, 5.f, 4.f, 5.f, 9.f, 4.f, 6.f, 4.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 5.f, 4.f, 3.f, 1.f, 3.f, 4.f, 3.f, 4.f, 4.f, 3.f, 6.f, 2.f, 3.f, 3.f, 2.f, 3.f, 3.f, 4.f, 5.f, 6.f, 5.f, 5.f, 3.f, 4.f, 5.f, 5.f, 4.f, 3.f, 4.f, 3.f, 6.f, 7.f, 6.f, 4.f, 6.f, 4.f, 3.f, 3.f, 4.f, 3.f, 5.f, 5.f, 4.f, 2.f, 3.f, 4.f, 5.f, 3.f, 4.f, 2.f, 4.f, 5.f, 3.f, 3.f, 7.f, 4.f, 2.f, 5.f, 6.f, 5.f, 5.f, 3.f, 1.f, 2.f, 4.f, 4.f, 1.f, 3.f, 6.f, 3.f, 3.f, 1.f, 4.f, 4.f, 4.f, 5.f, 3.f, 4.f, 3.f, 4.f, 2.f, 3.f, 3.f, 4.f, 3.f, 4.f, 3.f, 3.f, 4.f, 2.f, 5.f, 1.f, 3.f, 4.f, 2.f, 6.f, 4.f, 3.f, 4.f, 3.f, 3.f, 1.f, 2.f, 5.f, 2.f, 6.f, 4.f, 5.f, 6.f, 3.f, 6.f, 4.f, 4.f, 5.f, 3.f, 5.f, 6.f, 3.f, 4.f, 2.f, 4.f, 5.f, 5.f, 5.f, 2.f, 3.f, 4.f, 3.f, 5.f, 3.f, 3.f, 9.f, 6.f, 7.f, 7.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 4.f, 6.f, 5.f, 3.f, 5.f, 5.f, 5.f, 2.f, 4.f, 6.f, 7.f, 7.f, 5.f, 3.f, 4.f, 5.f, 4.f, 4.f, 5.f, 5.f, 5.f, 8.f, 4.f, 4.f, 4.f, 3.f, 5.f, 3.f, 3.f, 4.f, 4.f, 5.f, 3.f, 3.f, 2.f, 3.f, 6.f, 2.f, 5.f, 4.f, 4.f, 3.f, 3.f, 3.f, 5.f, 7.f, 2.f, 3.f, 2.f, 5.f, 5.f, 4.f, 4.f, 2.f, 2.f, 1.f, 6.f, 1.f, 2.f, 2.f, 3.f, 5.f, 4.f, 3.f, 5.f, 5.f, 3.f, 2.f, 2.f, 2.f, 2.f, 4.f, 3.f, 4.f, 4.f, 4.f, 4.f, 5.f, 2.f, 4.f, 4.f, 5.f, 2.f, 4.f, 4.f, 5.f, 9.f, 4.f, 5.f, 4.f, 3.f, 5.f, 5.f, 6.f, 4.f, 4.f, 3.f, 3.f, 6.f, 2.f, 3.f, 2.f, 5.f, 6.f, 
4.f, 4.f, 3.f, 5.f, 6.f, 4.f, 5.f, 5.f, 6.f, 7.f, 4.f, 2.f, 3.f, 5.f, 4.f, 4.f, 3.f, 5.f, 5.f, 4.f, 3.f, 4.f, 5.f, 4.f, 6.f, 3.f, 4.f, 4.f, 5.f, 6.f, 6.f, 4.f, 6.f, 6.f, 6.f, 5.f, 6.f, 6.f, 7.f, 7.f, 4.f, 3.f, 4.f, 4.f, 4.f, 5.f, 2.f, 5.f, 7.f, 5.f, 2.f, 1.f, 5.f, 5.f, 4.f, 1.f, 4.f, 1.f, 3.f, 3.f, 5.f, 4.f, 4.f, 3.f, 7.f, 3.f, 6.f, 3.f, 3.f, 4.f, 1.f, 3.f, 2.f, 3.f, 3.f, 4.f, 3.f, 1.f, 3.f, 4.f, 2.f, 4.f, 4.f, 2.f, 6.f, 1.f, 2.f, 2.f, 2.f, 3.f, 2.f, 3.f, 3.f, 4.f, 4.f, 4.f, 2.f, 4.f, 4.f, 4.f, 5.f, 5.f, 5.f, 4.f, 8.f, 5.f, 5.f, 3.f, 5.f, 3.f, 3.f, 2.f, 4.f, 3.f, 5.f, 6.f, 5.f, 3.f, 4.f, 5.f, 5.f, 3.f, 4.f, 3.f, 4.f, 8.f, 6.f, 5.f, 9.f, 6.f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_dot_misc.fb"); +// graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), result); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(77)); + + auto z = graph->getVariableSpace()->getVariable(77,0)->getNDArray(); + + ASSERT_EQ(e, *z); + + delete graph; +} + + +TEST_F(FlatBuffersTest, Test_MNIST_00_1) { + auto e = NDArrayFactory::create('c', {100, 10}, {0.00066107f, 0.00002358f, 0.00031518f, 0.00238039f, 0.00027216f, 0.00030300f, 0.00004659f, 0.98962247f, 0.00050380f, 0.00587174f, 0.05895791f, 0.00323104f, 0.52636790f, 0.12912551f, 0.00003951f, 0.03615341f, 0.22013727f, 0.00007333f, 0.02566659f, 0.00024759f, 0.00192367f, 0.90509874f, 0.01985082f, 0.02080356f, 0.00260053f, 0.00497826f, 0.01107823f, 0.00872595f, 0.01559795f, 0.00934229f, 0.98202229f, 0.00000150f, 0.00137381f, 0.00082931f, 0.00001806f, 0.00384426f, 0.00758274f, 0.00305049f, 0.00052152f, 0.00075617f, 0.01094264f, 0.00044708f, 0.03576852f, 0.00711267f, 0.65963465f, 0.00734364f, 0.02747800f, 0.06494589f, 0.02966754f, 0.15665947f, 0.00035806f, 0.95196360f, 0.00622721f, 0.01610696f, 0.00084180f, 0.00139947f, 0.00127350f, 0.00577912f, 0.00980321f, 0.00624705f, 0.00167418f, 0.00125611f, 0.00109477f, 0.04061511f, 0.57403159f, 0.08173440f, 
0.00423709f, 0.10187119f, 0.07103974f, 0.12244581f, 0.00073566f, 0.00624759f, 0.00559816f, 0.01215601f, 0.08299568f, 0.06209232f, 0.01742392f, 0.01341172f, 0.02181461f, 0.77752429f, 0.08474547f, 0.00957346f, 0.29235491f, 0.00243696f, 0.06653537f, 0.03792902f, 0.43910959f, 0.00344940f, 0.02626713f, 0.03759870f, 0.00143713f, 0.00011047f, 0.00018820f, 0.00047970f, 0.02127167f, 0.00308758f, 0.00093357f, 0.17067374f, 0.00545499f, 0.79636300f, 0.95257199f, 0.00002157f, 0.00647615f, 0.01024892f, 0.00005942f, 0.01910058f, 0.00044579f, 0.00008416f, 0.01097712f, 0.00001441f, 0.16705236f, 0.01782482f, 0.17580827f, 0.06262068f, 0.03860324f, 0.01763505f, 0.32766294f, 0.00555595f, 0.17227779f, 0.01495883f, 0.00180449f, 0.00010494f, 0.00075124f, 0.00161161f, 0.08859238f, 0.00364861f, 0.00162414f, 0.06005199f, 0.00805061f, 0.83375996f, 0.97355360f, 0.00000305f, 0.00144336f, 0.00051544f, 0.00010043f, 0.00714774f, 0.00021183f, 0.00042562f, 0.01294680f, 0.00365222f, 0.00026871f, 0.95752406f, 0.00408361f, 0.02153200f, 0.00015639f, 0.00153930f, 0.00323335f, 0.00178700f, 0.00516464f, 0.00471107f, 0.07408376f, 0.00468759f, 0.02638813f, 0.33325842f, 0.01172767f, 0.36993489f, 0.01118315f, 0.01460529f, 0.14850292f, 0.00562817f, 0.00551083f, 0.00015134f, 0.01184739f, 0.00643833f, 0.11686873f, 0.00163741f, 0.00582776f, 0.11497385f, 0.02010887f, 0.71663547f, 0.00154932f, 0.00001290f, 0.00023825f, 0.01393047f, 0.00012438f, 0.00033184f, 0.00010033f, 0.98197538f, 0.00022847f, 0.00150876f, 0.00597587f, 0.00819661f, 0.03041674f, 0.43121871f, 0.00986523f, 0.13834484f, 0.29576671f, 0.01305170f, 0.03919542f, 0.02796829f, 0.00139392f, 0.00031466f, 0.00229704f, 0.00647669f, 0.86193180f, 0.01064646f, 0.00494287f, 0.00901443f, 0.00526376f, 0.09771839f, 0.00184158f, 0.00040986f, 0.00008309f, 0.01634205f, 0.01102151f, 0.01133229f, 0.00011603f, 0.30489817f, 0.00813993f, 0.64581543f, 0.00132390f, 0.00009014f, 0.00471620f, 0.00419161f, 0.01024686f, 0.02504917f, 0.94500881f, 0.00010234f, 0.00620976f, 
0.00306121f, 0.00971363f, 0.05415262f, 0.05265132f, 0.01217585f, 0.16251956f, 0.00188165f, 0.61800343f, 0.04541704f, 0.01950107f, 0.02398386f, 0.05354780f, 0.00129718f, 0.00762409f, 0.06902183f, 0.01746517f, 0.71758413f, 0.04491642f, 0.00194128f, 0.07204670f, 0.01455537f, 0.00356139f, 0.00223315f, 0.01881612f, 0.01844147f, 0.65686893f, 0.01172961f, 0.01321550f, 0.06555344f, 0.00993031f, 0.19965005f, 0.99641657f, 0.00000005f, 0.00027076f, 0.00000523f, 0.00001288f, 0.00173779f, 0.00140848f, 0.00001787f, 0.00012701f, 0.00000342f, 0.00364264f, 0.00040242f, 0.00199880f, 0.01658181f, 0.00522031f, 0.00494563f, 0.00134627f, 0.87392259f, 0.00277323f, 0.08916643f, 0.00200165f, 0.00006030f, 0.00265544f, 0.00137030f, 0.85328883f, 0.00988892f, 0.00416652f, 0.00394441f, 0.00617034f, 0.11645336f, 0.97291315f, 0.00000182f, 0.00194084f, 0.01498440f, 0.00001028f, 0.00389095f, 0.00023297f, 0.00044887f, 0.00528154f, 0.00029516f, 0.00188889f, 0.79829764f, 0.01104437f, 0.04222726f, 0.00522182f, 0.04550264f, 0.03192228f, 0.01099020f, 0.04107348f, 0.01183154f, 0.00058263f, 0.00048307f, 0.00013920f, 0.96885711f, 0.00005209f, 0.01755359f, 0.00061751f, 0.00787173f, 0.00087605f, 0.00296709f, 0.00342248f, 0.68736714f, 0.01477064f, 0.11038199f, 0.00979373f, 0.03290173f, 0.02064420f, 0.03154078f, 0.03068676f, 0.05849051f, 0.00054699f, 0.00028973f, 0.00066918f, 0.79915440f, 0.00078404f, 0.18881910f, 0.00078736f, 0.00024780f, 0.00598373f, 0.00271761f, 0.37178108f, 0.00029151f, 0.11573081f, 0.00016159f, 0.08614764f, 0.05626433f, 0.33961067f, 0.00184490f, 0.01931754f, 0.00884999f, 0.00103338f, 0.00105793f, 0.01583840f, 0.01417849f, 0.00086645f, 0.00075313f, 0.00009471f, 0.92975640f, 0.00786521f, 0.02855594f, 0.00831110f, 0.00041050f, 0.95547730f, 0.01004958f, 0.00024040f, 0.00674337f, 0.01100292f, 0.00229303f, 0.00543977f, 0.00003204f, 0.00073861f, 0.00003656f, 0.00233217f, 0.00864751f, 0.00044351f, 0.00055325f, 0.00046273f, 0.97456056f, 0.00097461f, 0.01125053f, 0.00035382f, 0.94428235f, 
0.00286066f, 0.01286138f, 0.00111129f, 0.00731637f, 0.00518610f, 0.00538214f, 0.01197775f, 0.00866815f, 0.06013579f, 0.03228600f, 0.20441757f, 0.54548728f, 0.00006484f, 0.02362618f, 0.05482962f, 0.00106437f, 0.07713205f, 0.00095635f, 0.00029120f, 0.94839782f, 0.00271641f, 0.02038633f, 0.00010249f, 0.00270848f, 0.00299053f, 0.00069419f, 0.01599395f, 0.00571855f, 0.00580072f, 0.81594771f, 0.03097420f, 0.03646614f, 0.00565077f, 0.01715674f, 0.02362122f, 0.01730293f, 0.02312471f, 0.02395495f, 0.00083797f, 0.00032276f, 0.00475549f, 0.00577861f, 0.00193654f, 0.00201117f, 0.00095864f, 0.89032167f, 0.00238766f, 0.09068950f, 0.00007685f, 0.00309113f, 0.00165920f, 0.00566203f, 0.79406202f, 0.00106585f, 0.00073159f, 0.02779965f, 0.01331810f, 0.15253356f, 0.01362522f, 0.17258310f, 0.57671696f, 0.04606603f, 0.02204953f, 0.00909986f, 0.04971812f, 0.00135137f, 0.09417208f, 0.01461779f, 0.00351132f, 0.01659229f, 0.02209206f, 0.77456558f, 0.00303461f, 0.07932901f, 0.06269170f, 0.01151956f, 0.01363456f, 0.01302921f, 0.04056359f, 0.00052574f, 0.00214679f, 0.41835260f, 0.00373941f, 0.47472891f, 0.00819933f, 0.00047488f, 0.04602791f, 0.00524084f, 0.00085833f, 0.19585223f, 0.03986045f, 0.44138056f, 0.01866945f, 0.11297230f, 0.03688592f, 0.03147812f, 0.04306961f, 0.07897298f, 0.00580970f, 0.00654101f, 0.80165571f, 0.01388136f, 0.04366852f, 0.00407737f, 0.07712067f, 0.01289223f, 0.01437380f, 0.01997955f, 0.00013239f, 0.00000585f, 0.00003676f, 0.00288744f, 0.76327205f, 0.00911173f, 0.00025323f, 0.00345270f, 0.00977252f, 0.21107534f, 0.00238540f, 0.00011487f, 0.01707160f, 0.00274678f, 0.85196322f, 0.00066304f, 0.01279381f, 0.02112481f, 0.00446795f, 0.08666852f, 0.01046857f, 0.00011744f, 0.00377885f, 0.00806424f, 0.00110093f, 0.01087467f, 0.96216726f, 0.00024677f, 0.00213707f, 0.00104427f, 0.00835356f, 0.00037980f, 0.00540865f, 0.91882282f, 0.00084274f, 0.03935680f, 0.00700863f, 0.00609934f, 0.00307425f, 0.01065346f, 0.09310398f, 0.00066428f, 0.00076882f, 0.02210450f, 0.04447530f, 
0.77650899f, 0.00945148f, 0.00689890f, 0.00886871f, 0.03715509f, 0.07214937f, 0.00624633f, 0.01399398f, 0.29444799f, 0.03825752f, 0.36904955f, 0.02109544f, 0.01373637f, 0.14653027f, 0.02449317f, 0.01878268f, 0.01089148f, 0.36442387f, 0.01426089f, 0.02649262f, 0.00308395f, 0.51123023f, 0.00987128f, 0.02856500f, 0.01239803f, 0.65732223f, 0.00001665f, 0.00257388f, 0.02261361f, 0.00056261f, 0.08028404f, 0.00753943f, 0.00092872f, 0.22300763f, 0.00515121f, 0.00238470f, 0.00001802f, 0.00303019f, 0.00282769f, 0.93392336f, 0.00829813f, 0.00937593f, 0.00232166f, 0.00606702f, 0.03175319f, 0.00192149f, 0.89188498f, 0.01474108f, 0.03585867f, 0.00123343f, 0.00441551f, 0.00399710f, 0.00857630f, 0.01781271f, 0.01955875f, 0.00221238f, 0.00005268f, 0.00038176f, 0.00141851f, 0.07513693f, 0.00153898f, 0.00254140f, 0.04116146f, 0.00216117f, 0.87339473f, 0.17824675f, 0.04543359f, 0.01501061f, 0.03382575f, 0.09682461f, 0.29989448f, 0.02655865f, 0.16809541f, 0.09566309f, 0.04044705f, 0.00052125f, 0.00006512f, 0.00041621f, 0.03254773f, 0.00120942f, 0.00177929f, 0.00091721f, 0.95285058f, 0.00068729f, 0.00900588f, 0.04185560f, 0.00125587f, 0.33473280f, 0.00119652f, 0.00552071f, 0.03358750f, 0.04974457f, 0.00243473f, 0.41644078f, 0.11323092f, 0.00945223f, 0.00509389f, 0.04602458f, 0.02943204f, 0.23871920f, 0.06141117f, 0.05274383f, 0.03511769f, 0.09954999f, 0.42245534f, 0.00686926f, 0.01075546f, 0.49830484f, 0.37111449f, 0.00928881f, 0.00910977f, 0.00822666f, 0.00448587f, 0.04094843f, 0.04089646f, 0.00190534f, 0.00074783f, 0.02465805f, 0.02045769f, 0.02690129f, 0.00249506f, 0.00202899f, 0.84847659f, 0.01121813f, 0.06111111f, 0.00527403f, 0.00617689f, 0.00719898f, 0.17549324f, 0.25461593f, 0.15036304f, 0.04163047f, 0.01647436f, 0.08906800f, 0.25370511f, 0.10200825f, 0.03916828f, 0.22575049f, 0.08762794f, 0.06703069f, 0.01087492f, 0.27197123f, 0.15926389f, 0.02289790f, 0.01340644f, 0.00233572f, 0.00071111f, 0.01389953f, 0.00187034f, 0.89338356f, 0.00067592f, 0.00535080f, 0.02598928f, 
0.01003115f, 0.04575264f, 0.00010197f, 0.00006095f, 0.00021980f, 0.99164659f, 0.00011408f, 0.00474983f, 0.00004892f, 0.00012496f, 0.00257160f, 0.00036128f, 0.91125363f, 0.00012225f, 0.02511939f, 0.00156989f, 0.00002669f, 0.03335980f, 0.01791442f, 0.00531134f, 0.00345027f, 0.00187230f, 0.00210833f, 0.00001888f, 0.00016036f, 0.00394190f, 0.00016232f, 0.00026980f, 0.00012382f, 0.99098623f, 0.00036967f, 0.00185874f, 0.99578768f, 0.00000018f, 0.00162244f, 0.00012927f, 0.00000136f, 0.00158810f, 0.00016544f, 0.00000476f, 0.00069853f, 0.00000226f, 0.19834445f, 0.00044551f, 0.40857196f, 0.34896207f, 0.00023418f, 0.00828141f, 0.02426279f, 0.00148875f, 0.00938030f, 0.00002860f, 0.00201644f, 0.06109568f, 0.01542680f, 0.05984236f, 0.00112191f, 0.00419699f, 0.00110061f, 0.28937989f, 0.13231210f, 0.43350723f, 0.00055382f, 0.92216444f, 0.00396460f, 0.01456171f, 0.00061405f, 0.00972675f, 0.00677260f, 0.00454273f, 0.02471014f, 0.01238921f, 0.00027888f, 0.02572848f, 0.00290584f, 0.00748292f, 0.08441166f, 0.00232722f, 0.00188305f, 0.81133318f, 0.01191756f, 0.05173124f, 0.00315098f, 0.00499059f, 0.00158580f, 0.92859417f, 0.00035086f, 0.04807130f, 0.00101955f, 0.00034313f, 0.01119398f, 0.00069962f, 0.00112821f, 0.00214349f, 0.03968662f, 0.00325992f, 0.00253143f, 0.00199443f, 0.00964058f, 0.90529889f, 0.00384289f, 0.03047365f, 0.00174196f, 0.06674320f, 0.00283191f, 0.09274873f, 0.01944309f, 0.03424436f, 0.00694406f, 0.07912937f, 0.15087396f, 0.54529935f, 0.00007096f, 0.00001000f, 0.00001498f, 0.00007066f, 0.00002792f, 0.00005677f, 0.00000490f, 0.99606401f, 0.00030978f, 0.00337013f, 0.00286575f, 0.00011636f, 0.00064778f, 0.00992065f, 0.04501861f, 0.03149971f, 0.00287679f, 0.37334359f, 0.00214695f, 0.53156382f, 0.00600238f, 0.00003215f, 0.02112119f, 0.00084685f, 0.00497269f, 0.00753993f, 0.95174772f, 0.00150877f, 0.00212018f, 0.00410815f, 0.00006566f, 0.00001179f, 0.99827027f, 0.00028396f, 0.00004237f, 0.00000550f, 0.00091406f, 0.00003423f, 0.00036640f, 0.00000567f, 0.00079063f, 
0.00006855f, 0.00051338f, 0.00590454f, 0.00732460f, 0.00195139f, 0.00034534f, 0.90222436f, 0.00163695f, 0.07924022f, 0.00362202f, 0.01493629f, 0.01135249f, 0.00781013f, 0.05138498f, 0.22704794f, 0.00442778f, 0.00350683f, 0.59828150f, 0.07762999f, 0.00016529f, 0.00001219f, 0.00006521f, 0.00446292f, 0.94456083f, 0.00407963f, 0.00102245f, 0.00057420f, 0.00344479f, 0.04161252f, 0.00000981f, 0.00030270f, 0.00017082f, 0.00029943f, 0.00010159f, 0.00003605f, 0.00001875f, 0.99310946f, 0.00063157f, 0.00531995f, 0.01100852f, 0.00021492f, 0.00049603f, 0.59714299f, 0.00454595f, 0.33691072f, 0.03074775f, 0.00427598f, 0.00512297f, 0.00953417f, 0.00064403f, 0.00001687f, 0.00822414f, 0.00012918f, 0.02522905f, 0.00046274f, 0.95950085f, 0.00174588f, 0.00070707f, 0.00334025f, 0.00014754f, 0.96842438f, 0.00752080f, 0.00713038f, 0.00074491f, 0.00107368f, 0.00245372f, 0.00181830f, 0.00883226f, 0.00185409f, 0.00210863f, 0.00017522f, 0.00039881f, 0.98836052f, 0.00003650f, 0.00535216f, 0.00001887f, 0.00069545f, 0.00265663f, 0.00019714f, 0.00028919f, 0.00026057f, 0.00356666f, 0.00034738f, 0.00413719f, 0.00133701f, 0.98608136f, 0.00009625f, 0.00153734f, 0.00234698f, 0.01427079f, 0.04020482f, 0.04733688f, 0.03817881f, 0.16299380f, 0.04943828f, 0.03522370f, 0.05902825f, 0.23904003f, 0.31428465f, 0.00029359f, 0.00005619f, 0.00007707f, 0.98437482f, 0.00000957f, 0.00828004f, 0.00002787f, 0.00510217f, 0.00087425f, 0.00090444f, 0.00011413f, 0.83918202f, 0.01017746f, 0.03100164f, 0.00308035f, 0.01615586f, 0.02608237f, 0.00337026f, 0.05493741f, 0.01589854f, 0.00053240f, 0.00144792f, 0.00108170f, 0.00027300f, 0.86477506f, 0.00072790f, 0.01062538f, 0.00428096f, 0.00233054f, 0.11392505f, 0.00411633f, 0.33660546f, 0.01735369f, 0.18114267f, 0.03090077f, 0.11699959f, 0.03416851f, 0.06780743f, 0.07481573f, 0.13608985f, 0.00073468f, 0.20941530f, 0.01012138f, 0.17237675f, 0.01661461f, 0.02184150f, 0.03694551f, 0.30870155f, 0.04255475f, 0.18069389f, 0.06343270f, 0.00037455f, 0.06623310f, 0.00041474f, 
0.00209181f, 0.04566626f, 0.81232506f, 0.00054500f, 0.00807252f, 0.00084416f, 0.00008067f, 0.00003926f, 0.00225794f, 0.00115743f, 0.01925980f, 0.00010427f, 0.00062067f, 0.02234522f, 0.00210706f, 0.95202768f}); + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/mnist_00.fb"); + //graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), result); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); + + auto z = graph->getVariableSpace()->getVariable(6,0)->getNDArray(); + + ASSERT_EQ(e, *z); + + delete graph; +} + + + +TEST_F(FlatBuffersTest, Test_MNIST_1) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/mnist.fb"); + //graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), result); + + delete graph; +} + +/* +// FIXME: uncomment this test once conv_0 fb reexported +TEST_F(FlatBuffersTest, nhwc_conv_0) { + sd::ops::rank op1; + + auto exp('c', {4, 2}, {2.958640f, 0.602521f, 7.571267f, 1.496686f, -2.292647f, -1.791460f, 13.055838f, 4.278642f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); + + graph->printOut(); + + auto result = GraphExecutioner::execute(graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(11)); + + auto z = graph->getVariableSpace()->getVariable(11)->getNDArray(); + + //z->printShapeInfo("z buffr"); + //z->printIndexedBuffer("z shape"); + +// [[2.96, 0.60], +// [7.57, 1.50], +// [-2.29, -1.79], +// [13.06, 4.28]] + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete graph; +} + +*/ + + +/* +TEST_F(FlatBuffersTest, ReadLoops_SimpleWhile_1) { + // TF graph: + // https://gist.github.com/raver119/2aa49daf7ec09ed4ddddbc6262f213a0 + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/simple_while.fb"); + + ASSERT_TRUE(graph != nullptr); + + Nd4jStatus status = GraphExecutioner::execute(graph); + + 
ASSERT_EQ(ND4J_STATUS_OK, status); + + delete graph; +} + + */ +#endif diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/FlatUtilsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/FlatUtilsTests.cpp new file mode 100644 index 000000000..327a4e3c3 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/FlatUtilsTests.cpp @@ -0,0 +1,104 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +// +// @author raver119@gmail.com +// + +#include +#include +#include "testlayers.h" +#include +#include + +using namespace sd; + +class FlatUtilsTests : public testing::Test { +public: + +}; + +TEST_F(FlatUtilsTests, flat_float_serde_1) { + auto array = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} + +TEST_F(FlatUtilsTests, flat_int_serde_1) { + auto array = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} + +TEST_F(FlatUtilsTests, flat_bool_serde_1) { + auto array = NDArrayFactory::create('c', {4}, {true, false, true, false}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); + + ASSERT_EQ(array, *restored); + + delete restored; +} + +TEST_F(FlatUtilsTests, flat_string_serde_1) { + auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"}); + + flatbuffers::FlatBufferBuilder builder(1024); + auto flatArray = FlatUtils::toFlatArray(builder, array); + builder.Finish(flatArray); + + + auto pfArray = GetFlatArray(builder.GetBufferPointer()); + + auto restored = FlatUtils::fromFlatArray(pfArray); 
+ + ASSERT_EQ(array, *restored); + + delete restored; +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphExecutionerTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphExecutionerTests.cpp new file mode 100644 index 000000000..6de134010 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphExecutionerTests.cpp @@ -0,0 +1,105 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 29.11.17. 
+// + + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class GraphExecutionerTests : public testing::Test { +public: + +}; + +#ifdef GRAPH_TESTS_OK +TEST_F(GraphExecutionerTests, Test_Implicit_Output_1) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); + graph->buildGraph(); + + auto outputs = graph->fetchOutputs(); + + ASSERT_EQ(1, outputs->size()); + + auto var0 = outputs->at(0); + + ASSERT_EQ(7, var0->id()); + ASSERT_EQ(0, var0->index()); + + delete outputs; + delete graph; +} + + +TEST_F(GraphExecutionerTests, Test_Implicit_Output_2) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + graph->buildGraph(); + + auto outputs = graph->fetchOutputs(); + + ASSERT_EQ(1, outputs->size()); + + auto var0 = outputs->at(0); + + ASSERT_EQ(3, var0->id()); + ASSERT_EQ(0, var0->index()); + + delete outputs; + delete graph; +} + + +TEST_F(GraphExecutionerTests, Test_Implicit_Output_3) { + auto exp = NDArrayFactory::create('c', {3}, {3, 3, 3}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + auto status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto outputs = graph->fetchOutputs(); + + ASSERT_EQ(1, outputs->size()); + + auto var0 = outputs->at(0); + + ASSERT_EQ(3, var0->id()); + ASSERT_EQ(0, var0->index()); + + auto array = var0->getNDArray(); + + ASSERT_TRUE(array != nullptr); + + ASSERT_TRUE(exp.isSameShape(array)); + ASSERT_TRUE(exp.equalsTo(array)); + + delete outputs; + delete graph; +} +#endif diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphHolderTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphHolderTests.cpp new file mode 100644 index 000000000..61058095f --- /dev/null +++ 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphHolderTests.cpp @@ -0,0 +1,88 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 11.12.17. +// + +#include "testlayers.h" +#include + +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class GraphHolderTests : public testing::Test { +public: + +}; + +TEST_F(GraphHolderTests, SimpleTests_1) { + Graph graph; + Nd4jLong graphId = 119; + GraphHolder::getInstance().registerGraph(graphId, &graph); + + ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId)); + + GraphHolder::getInstance().forgetGraph(graphId); + + ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId)); +} + + + +TEST_F(GraphHolderTests, SimpleTests_2) { + auto graph = new Graph; + Nd4jLong graphId = 117; + GraphHolder::getInstance().registerGraph(graphId, graph); + + ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId)); + + auto graph2 = GraphHolder::getInstance().cloneGraph(graphId); + + ASSERT_TRUE(graph != graph2); + ASSERT_TRUE(graph2 != nullptr); + + GraphHolder::getInstance().forgetGraph(graphId); + + ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId)); + + delete 
graph; + delete graph2; +} + + +TEST_F(GraphHolderTests, SimpleTests_3) { + auto graph = new Graph; + Nd4jLong graphId = 117; + GraphHolder::getInstance().registerGraph(graphId, graph); + + ASSERT_TRUE(GraphHolder::getInstance().hasGraph(graphId)); + + auto graph2 = GraphHolder::getInstance().cloneGraph(graphId); + + ASSERT_TRUE(graph != graph2); + ASSERT_TRUE(graph2 != nullptr); + + GraphHolder::getInstance().dropGraph(graphId); + + ASSERT_FALSE(GraphHolder::getInstance().hasGraph(graphId)); + + + delete graph2; +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp new file mode 100644 index 000000000..0cc1c0114 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphRandomGeneratorTests.cpp @@ -0,0 +1,266 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class GraphRandomGeneratorTests : public testing::Test { +public: + +}; + +TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_1) { + sd::graph::RandomGenerator g0(119); + sd::graph::RandomGenerator g1(119); + + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + + ASSERT_EQ(i0, i1); +} + +TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_2) { + sd::graph::RandomGenerator g0(119); + sd::graph::RandomGenerator g1(117); + + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + + ASSERT_NE(i0, i1); +} + +TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_3) { + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 10); + + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + + ASSERT_NE(i0, i1); +} + +TEST_F(GraphRandomGeneratorTests, Reproducibility_Test_4) { + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(117, 5); + + auto i0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto i1 = g1.relativeT(15, 0, DataTypeUtils::max()); + + ASSERT_NE(i0, i1); +} + +TEST_F(GraphRandomGeneratorTests, Sequential_Test_1) { + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); + + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + + // values after rewind aren't equal + ASSERT_NE(r0, v0); + + // two generators must give the same output + ASSERT_EQ(v0, v1); + + // but not after one of them was rewinded + 
ASSERT_NE(r1, r0); +} + +TEST_F(GraphRandomGeneratorTests, Sequential_Test_2) { + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); + + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(199); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + + // values after rewind aren't equal + ASSERT_NE(r0, v0); + + // two generators must give the same output + ASSERT_EQ(v0, v1); + + // but not after they was rewinded with different number of elements + ASSERT_NE(r1, r0); +} + +TEST_F(GraphRandomGeneratorTests, Sequential_Test_3) { + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); + + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(200); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + + // values after rewind aren't equal + ASSERT_NE(r0, v0); + + // two generators must give the same output + ASSERT_EQ(v0, v1); + + // and here output must be equal as well + ASSERT_EQ(r1, r0); +} + +TEST_F(GraphRandomGeneratorTests, Sequential_Test_4) { + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); + + auto v0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto v1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(200); + auto r0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto r1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(200); + g1.rewindH(200); + auto z0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto z1 = g1.relativeT(15, 0, DataTypeUtils::max()); + g0.rewindH(201); + g1.rewindH(199); + auto y0 = g0.relativeT(15, 0, DataTypeUtils::max()); + auto y1 = g1.relativeT(15, 0, DataTypeUtils::max()); + + // values after rewind aren't equal + ASSERT_NE(r0, 
v0); + + // two generators must give the same output + ASSERT_EQ(v0, v1); + + // and here output must be equal as well + ASSERT_EQ(r0, r1); + + ASSERT_EQ(z0, z1); + + ASSERT_NE(r0, z0); + ASSERT_NE(r1, z1); + + ASSERT_NE(y0, z0); + ASSERT_NE(y1, z1); +} + + +//#ifndef __clang__ + +TEST_F(GraphRandomGeneratorTests, Long_Test_1) { + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); + + std::array z0, z1, z2, z3; + + for (int e = 0; e < z0.size(); e++) { + z0[e] = g0.relativeT(e); + z1[e] = g1.relativeT(e); + } + + g0.rewindH(z0.size()); + g1.rewindH(z0.size()); + + for (int e = 0; e < z0.size(); e++) { + z2[e] = g0.relativeT(e); + z3[e] = g1.relativeT(e); + } + + // these sequences should be equal + ASSERT_EQ(z0, z1); + ASSERT_EQ(z2, z3); + + // these sequences should be different due to rewind + ASSERT_NE(z0, z3); + + // we'll be counting values > MAX_INT here + int maxes = 0; + + for (int e = 0; e < z0.size(); e++) { + auto v = z0[e]; + + // we don't want any negatives here + ASSERT_TRUE(v > 0); + + if (v > DataTypeUtils::max()) + maxes++; + } + + // and now we're ensuring there ARE values above MAX_INT + ASSERT_NE(0, maxes); +} + + +TEST_F(GraphRandomGeneratorTests, FloatingPoint_Test_1) { + sd::graph::RandomGenerator g0(119, 5); + sd::graph::RandomGenerator g1(119, 5); + + std::array z0, z1, z2, z3; + + for (int e = 0; e < z0.size(); e++) { + z0[e] = g0.relativeT(e, -1.0, 1.0); + z1[e] = g1.relativeT(e, -1.0, 1.0); + } + + g0.rewindH(z0.size()); + g1.rewindH(z0.size()); + + for (int e = 0; e < z0.size(); e++) { + z2[e] = g0.relativeT(e, -1.0, 1.0); + z3[e] = g1.relativeT(e, -1.0, 1.0); + } + + // these sequences should be equal + ASSERT_EQ(z0, z1); + ASSERT_EQ(z2, z3); + + // these sequences should be different due to rewind + ASSERT_NE(z0, z3); + + // we'll count negatives as well + int negs = 0; + + // make sure every value stays within distribution borders + for (int e = 0; e < z0.size(); e++) { + auto v = z0[e]; + if (!(v >= -1.0 
&& v <= 1.0)) { + nd4j_printf("Failed at idx [%i]: %f\n", e, (float) v); + ASSERT_TRUE(v >= -1.0 && v <= 1.0); + } + + if (v < 0.0) + negs++; + } + + // there should be negatives + ASSERT_TRUE(negs > 0); + + // and positives + ASSERT_NE(z0.size(), negs); +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphStateTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphStateTests.cpp new file mode 100644 index 000000000..eabe9d965 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphStateTests.cpp @@ -0,0 +1,351 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class GraphStateTests : public testing::Test { +public: + GraphStateTests() { + Environment::getInstance().setDebug(false); + Environment::getInstance().setVerbose(false); + }; + + ~GraphStateTests() { + Environment::getInstance().setDebug(false); + Environment::getInstance().setVerbose(false); + } +}; + +/* + * PLAN: + * Create GraphState + * Register Scope + * Add few Ops to it + * Call conditional, that refers to scopes + * Check results + */ + +TEST_F(GraphStateTests, Basic_Tests_1) { + auto state = (GraphState *) getGraphState(117L); + ASSERT_EQ(117L, state->id()); + + // this call will create scope internally + state->registerScope(119); + + sd::ops::add opA; + sd::ops::LegacyTransformSameOp opB(transform::Neg); // simdOps::Neg + + ArgumentsList argsA; + ArgumentsList argsB; + + state->attachOpToScope(119, 1, &opA, argsA); + state->attachOpToScope(119, 2, &opB, argsB); + + auto scope = state->getScope(119); + ASSERT_TRUE(scope != nullptr); + ASSERT_EQ(2, scope->size()); + + deleteGraphState(state); +} + +// just separate case for doubles wrapper in NativeOps, nothing else +TEST_F(GraphStateTests, Basic_Tests_2) { + auto state = (GraphState *) getGraphState(117L); + ASSERT_EQ(117L, state->id()); + + // this call will create scope internally + state->registerScope(119); + + sd::ops::add opA; + sd::ops::LegacyTransformSameOp opB(transform::Neg); // simdOps::Neg + + ArgumentsList argsA; + ArgumentsList argsB; + + state->attachOpToScope(119, 1, &opA, argsA); + state->attachOpToScope(119, 2, &opB, argsB); + + auto scope = state->getScope(119); + ASSERT_TRUE(scope != nullptr); + ASSERT_EQ(2, scope->size()); + + deleteGraphState(state); +} + +/* +TEST_F(GraphStateTests, 
Stateful_Execution_1) { + auto state = getGraphState(117L); + + Nd4jLong scopes[] = {22, 33}; + //auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); + auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); + + ASSERT_EQ(Status::THROW(), status); + + deleteGraphState(state); +} + +TEST_F(GraphStateTests, Stateful_Execution_2) { + auto state = (GraphState *) getGraphState(117L); + + state->registerScope(22); + state->registerScope(33); + + Nd4jLong scopes[] = {22, 33}; + auto status = execCustomOpWithScope(nullptr, state, 10, scopes, 2, nullptr, nullptr, 0, nullptr, nullptr, 0); + // it's no-op: just LogicScope + ASSERT_EQ(Status::OK(), status); + + deleteGraphState(state); +} + +// This test checks WHILE loop +TEST_F(GraphStateTests, Stateful_Execution_3) { + auto var0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto var1 = NDArrayFactory::create(11.0f); + auto var2 = NDArrayFactory::create(2.0f); + + auto res0 = NDArrayFactory::create('c', {2, 2}); + auto res1 = NDArrayFactory::create(0.0f); + auto res2 = NDArrayFactory::create(0.0f); + + // registering our GraphState holder + auto state = (GraphState *) getGraphState(117L); + + // we're prepping pointers to input/output buffers + Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer(), (Nd4jPointer)var2.buffer()}; + Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo(), (Nd4jPointer)var2.shapeInfo()}; + + Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer(), (Nd4jPointer) res2.buffer()}; + Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo(), (Nd4jPointer) res2.shapeInfo()}; + + // conditional scope + state->registerScope(22); + + sd::ops::LegacyReduceSameOp op1(reduce::Sum); + sd::ops::lt_scalar op2; + + // while sum(var0) < var1 + // this op takes sum + 
ArgumentsList args1({{0, 0}}); + + // this op compares result of sum to input variable 0:1 + ArgumentsList args2({{1, 0}, {0, 1}}); + + state->attachOpToScope(22, 1, &op1, args1); + state->attachOpToScope(22, 2, &op2, args2); + + // body scope + state->registerScope(33); + + // var0 + var1 + var1 + // this op is var0 + var1 + ArgumentsList args3({{0, 0}, {0, 2}}); + + // this op is result of previous op + 1 + ArgumentsList args4({{3, 0}, {0, 2}}); + + sd::ops::add op3; + sd::ops::add op4; + + state->attachOpToScope(33, 3, &op3, args3); + state->attachOpToScope(33, 4, &op4, args4); + + // Now we define RETURN, which returns 1 modified variable, and 2 unmodified variables + ArgumentsList args5({{4, 0}, {0, 1}, {0, 2}}); + + // so, at the end of body, initial variables will be updated + state->defineReturn(33, 5, args5); + + Nd4jLong scopes[] = {22, 33}; + + // we're executing while loop + auto status = execCustomOpWithScope(nullptr, state, 0, scopes, 2, ptrBuffers, ptrShapes, 3, outBuffers, outShapes, 3); + ASSERT_EQ(Status::OK(), status); + + // now we check provided result array + float sum = res0.reduceNumber(reduce::Sum).e(0); + + // Expected result is {1, 2, 3, 4} + {2} elementwise + {2} elementwise, which gives { 5, 6, 7, 8}, and sum should be 26 + ASSERT_NEAR(26.0f, sum, 1e-5); + + // nd4j_printf("0 ------------------\n",""); + + deleteGraphState(state); + + // nd4j_printf("1 ------------------\n",""); +} + +// This test checks CONDITIONAL execution for FALSE +TEST_F(GraphStateTests, Stateful_Execution_4) { + auto var0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto var1 = NDArrayFactory::create(5.0f); + + auto res0 = NDArrayFactory::create('c', {2, 2}); + auto res1 = NDArrayFactory::create(0.0f); + + auto exp = NDArrayFactory::create('c', {2, 2}, {-4, -3, -2, -1}); + + + // registering our GraphState holder + auto state = (GraphState *) getGraphState(117L); + + // we're prepping pointers to input/output buffers + Nd4jPointer ptrBuffers[] = 
{(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; + Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo()}; + + Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer()}; + Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo()}; + + // conditional scope + state->registerScope(22); + + sd::ops::LegacyReduceSameOp op1(reduce::Sum); + sd::ops::lt_scalar op2; + + // if sum(var0) < var1 + // this op takes sum + ArgumentsList args1({{0, 0}}); + + // this op compares result of sum to input variable 0:1 + ArgumentsList args2({{1, 0}, {0, 1}}); + + state->attachOpToScope(22, 1, &op1, args1); + state->attachOpToScope(22, 2, &op2, args2); + + // false scope + state->registerScope(33); + + ArgumentsList args3({{0, 0}, {0, 1}}); + sd::ops::subtract op3; + state->attachOpToScope(33, 3, &op3, args3); + + // return for false scope + ArgumentsList args10({{3, 0}, {0, 1}}); + state->defineReturn(33, 10, args10); + + // true scope + state->registerScope(44); + + ArgumentsList args4({{0, 0}, {0, 1}}); + sd::ops::add op4; + state->attachOpToScope(44, 4, &op4, args4); + + // return for false scope + ArgumentsList args20({{4, 0}, {0, 1}}); + state->defineReturn(44, 20, args20); + + + Nd4jLong scopes[] = {22, 33, 44}; + + // we're executing conditional op + auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(exp.isSameShape(&res0)); + ASSERT_TRUE(exp.equalsTo(&res0)); + + + deleteGraphState(state); +} + + +// This test checks CONDITIONAL execution for TRUE +TEST_F(GraphStateTests, Stateful_Execution_5) { + auto var0 = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto var1 = NDArrayFactory::create(5.0f); + + auto res0 = NDArrayFactory::create('c', {2, 2}); + auto res1 = NDArrayFactory::create(0.0f); + + auto exp = NDArrayFactory::create('c', {2, 2}, {6, 7, 
8, 9}); + + + // registering our GraphState holder + auto state = (GraphState *) getGraphState(117L); + + // we're prepping pointers to input/output buffers + Nd4jPointer ptrBuffers[] = {(Nd4jPointer) var0.buffer(), (Nd4jPointer) var1.buffer()}; + Nd4jPointer ptrShapes[] = {(Nd4jPointer) var0.shapeInfo(), (Nd4jPointer) var1.shapeInfo()}; + + Nd4jPointer outBuffers[] = {(Nd4jPointer) res0.buffer(), (Nd4jPointer) res1.buffer()}; + Nd4jPointer outShapes[] = {(Nd4jPointer) res0.shapeInfo(), (Nd4jPointer) res1.shapeInfo()}; + + // conditional scope + state->registerScope(22); + + sd::ops::LegacyReduceSameOp op1(reduce::Sum); + sd::ops::gt_scalar op2; + + // if sum(var0) < var1 + // this op takes sum + ArgumentsList args1({{0, 0}}); + + // this op compares result of sum to input variable 0:1 + ArgumentsList args2({{1, 0}, {0, 1}}); + + state->attachOpToScope(22, 1, &op1, args1); + state->attachOpToScope(22, 2, &op2, args2); + + // false scope + state->registerScope(33); + + ArgumentsList args3({{0, 0}, {0, 1}}); + sd::ops::subtract op3; + state->attachOpToScope(33, 3, &op3, args3); + + // return for false scope + ArgumentsList args10({{3, 0}, {0, 1}}); + state->defineReturn(33, 10, args10); + + // true scope + state->registerScope(44); + + ArgumentsList args4({{0, 0}, {0, 1}}); + sd::ops::add op4; + state->attachOpToScope(44, 4, &op4, args4); + + // return for false scope + ArgumentsList args20({{4, 0}, {0, 1}}); + state->defineReturn(44, 20, args20); + + + Nd4jLong scopes[] = {22, 33, 44}; + + // we're executing conditional op + auto status = execCustomOpWithScope(nullptr, state, 20, scopes, 3, ptrBuffers, ptrShapes, 2, outBuffers, outShapes, 2); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(exp.isSameShape(&res0)); + ASSERT_TRUE(exp.equalsTo(&res0)); + + deleteGraphState(state); +} +*/ \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphTests.cpp 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphTests.cpp new file mode 100644 index 000000000..d2c82f219 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/GraphTests.cpp @@ -0,0 +1,1640 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class GraphTests : public testing::Test { +public: + /* + int cShape[] = {2, 2, 2, 2, 1, 0, 1, 99}; + int fShape[] = {2, 2, 2, 1, 2, 0, 1, 102}; + */ + GraphTests() { + //Environment::getInstance().setDebug(true); + //Environment::getInstance().setVerbose(true); + } +}; + +TEST_F(GraphTests, SingleInput1) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0f); + + graph->getVariableSpace()->putVariable(-1, x); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_STRICT, transform::Cosine, 2, {1}, {3}); + auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Abs, 3, {2}, 
{}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + graph->addNode(nodeC); + + ASSERT_EQ(1, graph->rootNodes()); + ASSERT_EQ(3, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(3)); + + auto node3 = graph->getVariableSpace()->getVariable(3)->getNDArray(); + + ASSERT_NEAR(0.4161468, node3->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + +TEST_F(GraphTests, DoubleInput1) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto y = NDArrayFactory::create_('c', {5, 5}); + y->assign(-1.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, y); + graph->getVariableSpace()->putVariable(-3, z); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {3}); + auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {3}); + auto nodeC = new Node(OpType_PAIRWISE, pairwise::Add, 3, {1, 2}, {-3}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + graph->addNode(nodeC); + + ASSERT_EQ(2, graph->rootNodes()); + ASSERT_EQ(3, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + +TEST_F(GraphTests, SingleInput3) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto v0 = NDArrayFactory::create_('c', {5, 5}); + auto v1 = NDArrayFactory::create_('c', {5, 5}); + + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, v0); + graph->getVariableSpace()->putVariable(-3, v1); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2, 3}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); + auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Ones, 3, {1}, {-3}); + + graph->addNode(nodeA); + 
graph->addNode(nodeB); + graph->addNode(nodeC); + + ASSERT_EQ(1, graph->rootNodes()); + ASSERT_EQ(3, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(1.4142135, v0->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(1.0, v1->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + +TEST_F(GraphTests, SingleInput4) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto v0 = NDArrayFactory::create_('c', {5, 5}); + auto v1 = NDArrayFactory::create_('c', {5, 5}); + + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, v0); + graph->getVariableSpace()->putVariable(-3, v1); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); + auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {4, 5}); + + auto nodeS = new Node(OpType_TRANSFORM_SAME, transform::Ones, 4, {3}, {-2}); + auto nodeE = new Node(OpType_TRANSFORM_SAME, transform::Identity, 5, {3}, {-3}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + graph->addNode(nodeC); + graph->addNode(nodeS); + graph->addNode(nodeE); + + ASSERT_EQ(1, graph->rootNodes()); + ASSERT_EQ(5, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(1.0, v0->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(-1.4142135, v1->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + + +TEST_F(GraphTests, DoubleInput2) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto y = NDArrayFactory::create_('c', {5, 5}); + y->assign(-1.0); + + auto z0 = NDArrayFactory::create_('c', {5, 5}); + auto z1 = NDArrayFactory::create_('c', {5, 5}); + + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, y); + graph->getVariableSpace()->putVariable(-3, z0); + 
graph->getVariableSpace()->putVariable(-4, z1); + + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); + auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {-3}); + + auto nodeT = new Node(OpType_TRANSFORM_SAME, transform::Abs, 11, {-2}, {12}); + auto nodeU = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 12, {11}, {13}); + auto nodeV = new Node(OpType_TRANSFORM_SAME, transform::Neg, 13, {12}, {-4}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + graph->addNode(nodeC); + graph->addNode(nodeT); + graph->addNode(nodeU); + graph->addNode(nodeV); + + ASSERT_EQ(2, graph->rootNodes()); + ASSERT_EQ(6, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(-1.4142135, z0->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(-1.0, z1->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + + +TEST_F(GraphTests, DoubleInput3) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto y = NDArrayFactory::create_('c', {5, 5}); + y->assign(-1.0); + + auto z0 = NDArrayFactory::create_('c', {5, 5}); + auto z1 = NDArrayFactory::create_('c', {5, 5}); + + + auto w = NDArrayFactory::create_('c', {5, 5}); + + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, y); + graph->getVariableSpace()->putVariable(-3, z0); + graph->getVariableSpace()->putVariable(-4, z1); + graph->getVariableSpace()->putVariable(-5, w); + + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); + auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Neg, 3, {2}, {-3, 21}); + + auto nodeT = new Node(OpType_TRANSFORM_SAME, transform::Abs, 11, {-2}, {12}); + auto nodeU = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 12, {11}, {13}); + auto nodeV = new 
Node(OpType_TRANSFORM_SAME, transform::Neg, 13, {12}, {-4, 21}); + + auto nodeW = new Node(OpType_PAIRWISE, pairwise::Add, 21, {3, 13}, {22}); + auto nodeZ = new Node(OpType_TRANSFORM_SAME, transform::Abs, 22, {21}, {-5}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + graph->addNode(nodeC); + graph->addNode(nodeT); + graph->addNode(nodeU); + graph->addNode(nodeV); + graph->addNode(nodeW); + graph->addNode(nodeZ); + + ASSERT_EQ(2, graph->rootNodes()); + ASSERT_EQ(8, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(-1.4142135, z0->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(-1.0, z1->reduceNumber(reduce::Mean).e(0), 1e-5); + + ASSERT_NEAR(2.4142135, w->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + + +TEST_F(GraphTests, QuadInput1) { + auto graph = new Graph(); + + auto x0 = NDArrayFactory::create_('c', {5, 5}); + x0->assign(0.0); + + auto x1 = NDArrayFactory::create_('c', {5, 5}); + x1->assign(-1.0); + + auto x2 = NDArrayFactory::create_('c', {5, 5}); + x2->assign(-2.0); + + auto x3 = NDArrayFactory::create_('c', {5, 5}); + x3->assign(-3.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + z->assign(119.0); + + graph->getVariableSpace()->putVariable(-1, x0); + graph->getVariableSpace()->putVariable(-2, x1); + graph->getVariableSpace()->putVariable(-3, x2); + graph->getVariableSpace()->putVariable(-4, x3); + graph->getVariableSpace()->putVariable(-5, z); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {11}); + auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {11}); + auto nodeC = new Node(OpType_TRANSFORM_SAME, transform::Abs, 3, {-3}, {21}); + auto nodeD = new Node(OpType_TRANSFORM_SAME, transform::Abs, 4, {-4}, {21}); + + auto nodeP1 = new Node(OpType_PAIRWISE, pairwise::Add, 11, {1, 2}, {31}); + auto nodeP2 = new Node(OpType_PAIRWISE, pairwise::Add, 21, {3, 4}, {31}); + + auto nodeZ = new Node(OpType_PAIRWISE, pairwise::Add, 31, {11, 21}, {-5}); + + 
graph->addNode(nodeA); + graph->addNode(nodeB); + graph->addNode(nodeC); + graph->addNode(nodeD); + graph->addNode(nodeP1); + graph->addNode(nodeP2); + graph->addNode(nodeZ); + + ASSERT_EQ(4, graph->rootNodes()); + ASSERT_EQ(7, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(6.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + +TEST_F(GraphTests, InternalBranching1) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(0.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, z); + + // 1.0 + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Ones, 1, {-1}, {11, 21}); + + // -1 + auto nodeK = new Node(OpType_TRANSFORM_SAME, transform::Neg, 11, {1}, {12}); + + // 2.0 + auto nodeL = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 12, {11}, {31}); + + // -1 + auto nodeR = new Node(OpType_TRANSFORM_SAME, transform::Neg, 21, {1}, {22}); + + // 1 + auto nodeS = new Node(OpType_TRANSFORM_SAME, transform::Neg, 22, {21}, {31}); + + // 1.0 + auto nodeZ = new Node(OpType_PAIRWISE, pairwise::Add, 31, {12, 22}, {-2}); + + graph->addNode(nodeA); + graph->addNode(nodeK); + graph->addNode(nodeL); + graph->addNode(nodeR); + graph->addNode(nodeS); + graph->addNode(nodeZ); + + ASSERT_EQ(1, graph->rootNodes()); + ASSERT_EQ(6, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_EQ(3, nodeZ->getLayer()); + + ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + + +TEST_F(GraphTests, ReductionsTest1) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + for (int r = 0; r < x->rows(); r++) { + for (int c = 0; c < x->columns(); c++) { + x->p(r, c, -c); + } + } + + auto z = NDArrayFactory::create_('c', {5}); + + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, z); + +// 
sd::graph::Node::Node(OpType opType, int opNum, int id, std::initializer_list input, std::initializer_list output, std::initializer_list dimensions, float scalar, std::initializer_list tArgs, std::initializer_list iArgs) { + + auto nodeA = new Node(OpType_REDUCE_FLOAT, reduce::Mean, 1, {-1}, {2}, {1}, {}); + auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + ASSERT_EQ(1, graph->rootNodes()); + ASSERT_EQ(2, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(2.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + + +TEST_F(GraphTests, IndexReductionsTest1) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + for (int r = 0; r < x->rows(); r++) { + for (int c = 0; c < x->columns(); c++) { + x->p(r, c, -c); + } + } + + auto z = NDArrayFactory::create_('c', {5, 1}); + auto axis = NDArrayFactory::create_('c', {1}, {1}); + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, z); + //graph->getVariableSpace()->putVariable(-3, axis); + + + auto nodeA = new Node(OpType_INDEX_REDUCE, indexreduce::IndexMin, 1, {-1}, {2}, {1}); + auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {-2}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + ASSERT_EQ(1, graph->rootNodes()); + ASSERT_EQ(2, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(4.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; + delete axis; +} + +#if 0 +TEST_F(GraphTests, AutoOutput1) { + auto graph = new Graph(); + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + graph->getVariableSpace()->putVariable(-1, x); + + auto nodeA = new Node(OpType_TRANSFORM_FLOAT, 0, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, 35, 2, {1}, {}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + ASSERT_EQ(1, graph->rootNodes()); + ASSERT_EQ(2, 
graph->totalNodes()); + + graph->buildGraph(); + + ASSERT_TRUE(graph->getVariableSpace()->getVariable(2) != nullptr); + + GraphExecutioner::execute(graph); + + auto outputs = graph->fetchOutputs(); + + ASSERT_EQ(1, outputs->size()); + + ASSERT_TRUE(outputs->at(0) != nullptr); + + ASSERT_NEAR(-1.0, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete outputs; + delete graph; +} + + +TEST_F(GraphTests, AutoOutput2) { + auto graph = new Graph(); + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + graph->getVariableSpace()->putVariable(-1, x); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2, 3, -1}); + auto nodeB = new Node(OpType_TRANSFORM_SAME, 35, 2, {1}, {}); + auto nodeC = new Node(OpType_TRANSFORM_SAME, 6, 3, {1}, {}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + graph->addNode(nodeC); + + ASSERT_EQ(1, graph->rootNodes()); + ASSERT_EQ(3, graph->totalNodes()); + + graph->buildGraph(); + + ASSERT_TRUE(graph->getVariableSpace()->getVariable(-1) != nullptr); + ASSERT_TRUE(graph->getVariableSpace()->getVariable(2) != nullptr); + ASSERT_TRUE(graph->getVariableSpace()->getVariable(3) != nullptr); + + GraphExecutioner::execute(graph); + + auto outputs = graph->fetchOutputs(); + + ASSERT_EQ(2, outputs->size()); + + ASSERT_TRUE(outputs->at(0) != nullptr); + + ASSERT_NEAR(-1.0, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); + ASSERT_NEAR(-2.0, outputs->at(1)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; + delete outputs; +} +#endif + +TEST_F(GraphTests, BroadcastTest1) { + auto graph = new Graph(); + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(0.f); + + auto y = NDArrayFactory::create_('c', {1, 5}); + for (int e = 0; e < y->columns(); e++) { + y->p(e, (float)e+1); + } + + auto z = NDArrayFactory::create_('c', {5, 5}); + + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, y); + 
graph->getVariableSpace()->putVariable(-3, z); + + auto nodeA = new Node(OpType_BROADCAST, broadcast::Subtract, 1, {-1, -2}, {2}, {1}); + auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Neg, 2, {1}, {-3}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(3.0, z->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + + +TEST_F(GraphTests, ScalarTest1) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, z); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {3}); + auto nodeE = new Node(OpType_SCALAR, scalar::Add, 3, {2}, {-2}, {}, 1.3f); + + graph->addNode(nodeA); + graph->addNode(nodeB); + graph->addNode(nodeE); + + ASSERT_EQ(1, graph->rootNodes()); + ASSERT_EQ(3, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(2.714213, z->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + +TEST_F(GraphTests, SymbolicLookupTest1) { + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + + auto vX = new Variable(x); + auto vZ = new Variable(z); + + std::string a("alpha"); + std::string o("omega"); + + vX->setName(&a); + vZ->setName(&o); + + graph->getVariableSpace()->putVariable(-1, vX); + graph->getVariableSpace()->putVariable(-2, vZ); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); + + std::string p("phi"); + std::string t("theta"); + + nodeA->setName(&p); + nodeB->setName(&t); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + + auto rX = 
graph->getVariableSpace()->getVariable(&a); + auto rZ = graph->getVariableSpace()->getVariable(&o); + + std::string om("omicron"); + + ASSERT_TRUE(rX->getNDArray() == vX->getNDArray()); + ASSERT_TRUE(rZ->getNDArray() == vZ->getNDArray()); + ASSERT_FALSE(graph->getVariableSpace()->hasVariable(&om)); + + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(2)); + + GraphExecutioner::execute(graph); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(&p)); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(&t)); + + ASSERT_NEAR(1.4142135, z->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; +} + +TEST_F(GraphTests, OutputValidation1) { + auto graph = new Graph(); + + graph->getExecutorConfiguration()->_outputMode = OutputMode_EXPLICIT; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + + auto vX = new Variable(x); + auto vZ = new Variable(z); + + std::string a("alpha"); + std::string o("omega"); + + vX->setName(&a); + vZ->setName(&o); + + graph->getVariableSpace()->putVariable(-1, vX); + graph->getVariableSpace()->putVariable(-2, vZ); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + + auto outputs = graph->fetchOutputs(); + + ASSERT_EQ(0, outputs->size()); + + delete graph; + delete outputs; +} + +TEST_F(GraphTests, OutputValidation2) { + auto graph = new Graph(); + + graph->getExecutorConfiguration()->_outputMode = OutputMode_EXPLICIT; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + + auto vX = new Variable(x); + auto vZ = new Variable(z); + + std::string a("alpha"); + std::string o("omega"); + + vX->setName(&a); + vZ->setName(&o); + + graph->getVariableSpace()->putVariable(-1, vX); + 
graph->getVariableSpace()->putVariable(-2, vZ); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + graph->addOutput(-2); + + GraphExecutioner::execute(graph); + + auto outputs = graph->fetchOutputs(); + + ASSERT_EQ(1, outputs->size()); + + ASSERT_NEAR(1.4142135, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; + delete outputs; +} + +TEST_F(GraphTests, OutputValidation3) { + auto graph = new Graph(); + + graph->getExecutorConfiguration()->_outputMode = OutputMode_IMPLICIT; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + + auto vX = new Variable(x); + auto vZ = new Variable(z); + + std::string a("alpha"); + std::string o("omega"); + + vX->setName(&a); + vZ->setName(&o); + + graph->getVariableSpace()->putVariable(-1, vX); + graph->getVariableSpace()->putVariable(-2, vZ); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {}); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + GraphExecutioner::execute(graph); + + auto outputs = graph->fetchOutputs(); + + ASSERT_EQ(1, outputs->size()); + + ASSERT_NEAR(1.4142135, outputs->at(0)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; + delete outputs; +} + +TEST_F(GraphTests, OutputValidation4) { + auto graph = new Graph(); + + graph->getExecutorConfiguration()->_outputMode = OutputMode_EXPLICIT_AND_IMPLICIT; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + + auto vX = new Variable(x); + auto vZ = new Variable(z); + + std::string a("alpha"); + std::string o("omega"); + + vX->setName(&a); + vZ->setName(&o); + + graph->getVariableSpace()->putVariable(-1, 
vX); + graph->getVariableSpace()->putVariable(-2, vZ); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {-2}); + + graph->addOutput(-1); + + // not a typo. we want this value only once + graph->addOutput(-1); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + GraphExecutioner::execute(graph); + + auto outputs = graph->fetchOutputs(); + + ASSERT_EQ(2, outputs->size()); + + ASSERT_NEAR(1.4142135, outputs->at(1)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); + + delete graph; + delete outputs; +} + + +TEST_F(GraphTests, OutputValidation5) { + auto graph = new Graph(); + + graph->getExecutorConfiguration()->_outputMode = OutputMode_VARIABLE_SPACE; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + + auto vX = new Variable(x); + auto vZ = new Variable(z); + + std::string a("alpha"); + std::string o("omega"); + + vX->setName(&a); + vZ->setName(&o); + + graph->getVariableSpace()->putVariable(-1, vX); + graph->getVariableSpace()->putVariable(-2, vZ); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Sqrt, 2, {1}, {-2}); + + graph->addOutput(-1); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + GraphExecutioner::execute(graph); + + auto outputs = graph->fetchOutputs(); + + ASSERT_EQ(4, outputs->size()); + + delete graph; + delete outputs; +} + +TEST_F(GraphTests, OutputValidation6) { + auto graph = new Graph(); + + graph->getExecutorConfiguration()->_outputMode = OutputMode_VARIABLE_SPACE; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto z = NDArrayFactory::create_('c', {5, 5}); + + auto vX = new Variable(x); + auto vZ = new Variable(z); + + std::string a("alpha"); + std::string o("omega"); + + vX->setName(&a); + vZ->setName(&o); + + 
graph->getVariableSpace()->putVariable(-1, vX); + graph->getVariableSpace()->putVariable(-2, vZ); + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_FLOAT, transform::Sqrt, 2, {1}, {}); + + //graph->addOutput(-1); + + graph->addNode(nodeA); + graph->addNode(nodeB); + + GraphExecutioner::execute(graph); + + auto outputs = graph->fetchOutputs(); + +// nd4j_printf("Returned variables: \n", ""); +// for (int e = 0; e < outputs->size(); e++) { +// printf("%i, ", outputs->at(e)->id()); +// } +// printf("\n"); + + ASSERT_EQ(4, outputs->size()); + + //ASSERT_NEAR(1.4142135, graph->fetchOutputs()->at(1)->getNDArray()->reduceNumber(reduce::Mean).e(0), 1e-5); + delete graph; + delete outputs; +} + +TEST_F(GraphTests, TestMultiOutput1) { + sd::ops::testop2i2o op1; + auto graph = new Graph(); + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + auto y = NDArrayFactory::create_('c', {5, 5}); + y->assign(-3.0); + + graph->getVariableSpace()->putVariable(-1, x); + graph->getVariableSpace()->putVariable(-2, y); + + + // Abs + auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {11}); + nodeA0->markInplace(false); + auto nodeB0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {-2}, {11}); + nodeB0->markInplace(false); + + auto op = sd::ops::OpRegistrator::getInstance().getOperation("testop2i2o"); + + // this op will add 1.0 to first input, and 2.0 for second input + auto nodeT = new Node(op, 11, {1, 2}, {21, 31}, {}, 0.0f); + nodeT->setName("TestOp2i2o"); + nodeT->markInplace(false); + + + // this op will subtract this value from 1.0 + auto nodeX = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 21); + nodeX->markInplace(false); + nodeX->pickInput(11, 0); + + // this op will subtract this value from 1.0 + auto nodeY = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 31); + nodeY->markInplace(false); + nodeY->pickInput(11, 1); + + graph->addNode(nodeA0); + 
graph->addNode(nodeB0); + graph->addNode(nodeT); + graph->addNode(nodeX); + graph->addNode(nodeY); + + std::pair pair0(11,0); + std::pair pair1(11,1); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(pair0)); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(pair1)); + + Nd4jStatus status = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_NEAR(-2.0f, graph->getVariableSpace()->getVariable(21)->getNDArray()->meanNumber().e(0), 1e-5); + ASSERT_NEAR(-4.0f, graph->getVariableSpace()->getVariable(31)->getNDArray()->meanNumber().e(0), 1e-5); + + delete graph; +} + +TEST_F(GraphTests, TestDivergentNode1) { + auto op = sd::ops::OpRegistrator::getInstance().getOperation("Switch"); + auto nodeY = new Node(op, 1); + + ASSERT_TRUE(nodeY->isDivergencePoint()); + ASSERT_TRUE(nodeY->isActive()); + + delete nodeY; +} + + +TEST_F(GraphTests, MemoryEstimationTest1) { + Graph graph; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + graph.getVariableSpace()->putVariable(-1, x); + + auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); + nodeA1->markInplace(false); + + graph.addNode(nodeA0); + graph.addNode(nodeA1); + + ASSERT_EQ(2, graph.totalNodes()); + ASSERT_EQ(1, graph.rootNodes()); + + auto memReq = graph.estimateRequiredMemory(); + + ASSERT_EQ(25 * x->sizeOfT(), memReq); +} + +TEST_F(GraphTests, MemoryEstimationTest2) { + Graph graph; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + graph.getVariableSpace()->putVariable(-1, x); + + auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); + //nodeA1->markInplace(false); + + graph.addNode(nodeA0); + graph.addNode(nodeA1); + + ASSERT_EQ(2, graph.totalNodes()); + ASSERT_EQ(1, graph.rootNodes()); + + auto memReq = graph.estimateRequiredMemory(); + + 
ASSERT_EQ(0, memReq); +} + +TEST_F(GraphTests, MemoryEstimationTest3) { + Graph graph; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + graph.getVariableSpace()->putVariable(-1, x); + + auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); + auto nodeA2 = new Node(OpType_REDUCE_FLOAT, reduce::Mean, 3, {2}, {}, {}); + nodeA1->markInplace(false); + + graph.addNode(nodeA0); + graph.addNode(nodeA1); + graph.addNode(nodeA2); + + ASSERT_EQ(3, graph.totalNodes()); + ASSERT_EQ(1, graph.rootNodes()); + + auto memReq = graph.estimateRequiredMemory(); + + ASSERT_EQ(26 * x->sizeOfT(), memReq); +} + +TEST_F(GraphTests, MemoryEstimationTest4) { + Graph graph; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + graph.getVariableSpace()->putVariable(-1, x); + + auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); + auto nodeA2 = new Node(OpType_REDUCE_FLOAT, reduce::Mean, 3, {2}, {}, {1}); + nodeA1->markInplace(false); + + graph.addNode(nodeA0); + graph.addNode(nodeA1); + graph.addNode(nodeA2); + + ASSERT_EQ(3, graph.totalNodes()); + ASSERT_EQ(1, graph.rootNodes()); + + auto memReq = graph.estimateRequiredMemory(); + + ASSERT_EQ(30 * x->sizeOfT(), memReq); +} + +TEST_F(GraphTests, MemoryEstimationTest5) { + Graph graph; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-2.0); + + graph.getVariableSpace()->putVariable(-1, x); + + sd::ops::testcustom op; + + auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); + auto nodeA2 = new Node(&op, 3, {2}, {}, {}); + nodeA1->markInplace(false); + + + graph.addNode(nodeA0); + graph.addNode(nodeA1); + graph.addNode(nodeA2); + + graph.buildGraph(); + + ASSERT_EQ(3, graph.totalNodes()); 
+ ASSERT_EQ(1, graph.rootNodes()); + + auto memReq = graph.estimateRequiredMemory(); + + ASSERT_EQ((25 + 100) * x->sizeOfT(), memReq); +} + +TEST_F(GraphTests, TestGraphInGraph_1) { + // this one is external graph + Graph graphA; + + // and this ons is embedded + Graph graphB; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-5.0); + + auto modifier = NDArrayFactory::create_('c', {5, 5}); + modifier->assign(3.0); + + graphA.getVariableSpace()->putVariable(-1, x); + graphB.getVariableSpace()->putVariable(-2, modifier); + + // this is placeholder variable + graphB.getVariableSpace()->putVariable(-1, new Variable(true)); + + // abs, result is 5 + auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + // 1-, result -4 + auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 2, {1}, {3}); + + // graph should return 12: abs(3.0 x -4) + auto nodeA2 = new Node(OpType_GRAPH, -1, 3, {2}, {4}); + + // 1 - 12 = -11 + auto nodeA3 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 4, {3}, {}); + + nodeA2->setGraph(&graphB); + + graphA.addNode(nodeA0); + graphA.addNode(nodeA1); + graphA.addNode(nodeA2); + graphA.addNode(nodeA3); + + // this is going to be PWT + auto nodeB0 = new Node(OpType_PAIRWISE, pairwise::Multiply, 1, {-1, -2}, {2}); + auto nodeB1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); + + graphB.addNode(nodeB0); + graphB.addNode(nodeB1); + + graphB.buildGraph(); + graphA.buildGraph(); + + ASSERT_EQ(0, nodeA0->getLayer()); + ASSERT_EQ(1, nodeA1->getLayer()); + ASSERT_EQ(2, nodeA2->getLayer()); + ASSERT_EQ(3, nodeA3->getLayer()); + + ASSERT_EQ(0, nodeB0->getLayer()); + ASSERT_EQ(1, nodeB1->getLayer()); + + Nd4jStatus status = GraphExecutioner::execute(&graphA); + ASSERT_EQ(ND4J_STATUS_OK, status); + + float m = graphA.getVariableSpace()->getVariable(4)->getNDArray()->meanNumber().e(0); + + //nd4j_printf("OpResult: %f\n", m); + + ASSERT_NEAR(-11.0, m, 1e-5); +} + +// test for symbolic lookup 
+TEST_F(GraphTests, TestGraphInGraph_2) { + // this one is external graph + Graph graphA; + + // and this ons is embedded + Graph graphB; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-5.0); + + auto modifier = NDArrayFactory::create_('c', {5, 5}); + modifier->assign(3.0); + + std::string nameA1("_nodeA1"); + + graphA.getVariableSpace()->putVariable(-1, x); + graphB.getVariableSpace()->putVariable(-2, modifier); + + // this is placeholder variable + auto placeHolder = new Variable(true); + placeHolder->setName(&nameA1); + graphB.getVariableSpace()->putVariable(-1, placeHolder); + + // abs, result is 5 + auto nodeA0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + // 1-, result -4 + auto nodeA1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 2, {1}, {3}); + nodeA1->setName(nameA1); + + // graph should return 12: abs(3.0 x -4) + auto nodeA2 = new Node(OpType_GRAPH, -1, 3, {2}, {4}); + + // 1 - 12 = -11 + auto nodeA3 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 4, {3}, {}); + + nodeA2->setGraph(&graphB); + + graphA.addNode(nodeA0); + graphA.addNode(nodeA1); + graphA.addNode(nodeA2); + graphA.addNode(nodeA3); + + // this is going to be PWT + auto nodeB0 = new Node(OpType_PAIRWISE, pairwise::Multiply, 1, {-1, -2}, {2}); + auto nodeB1 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {}); + + graphB.addNode(nodeB0); + graphB.addNode(nodeB1); + + graphB.buildGraph(); + graphA.buildGraph(); + + ASSERT_EQ(0, nodeA0->getLayer()); + ASSERT_EQ(1, nodeA1->getLayer()); + ASSERT_EQ(2, nodeA2->getLayer()); + ASSERT_EQ(3, nodeA3->getLayer()); + + ASSERT_EQ(0, nodeB0->getLayer()); + ASSERT_EQ(1, nodeB1->getLayer()); + + Nd4jStatus status = GraphExecutioner::execute(&graphA); + ASSERT_EQ(ND4J_STATUS_OK, status); + + float m = graphA.getVariableSpace()->getVariable(4)->getNDArray()->meanNumber().e(0); + + //nd4j_printf("OpResult: %f\n", m); + + ASSERT_NEAR(-11.0, m, 1e-5); +} + +#if 0 +TEST_F(GraphTests, Test_Clone_1) { + 
auto exp = NDArrayFactory::create('c', {3}); + exp.assign(3.0); + + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + auto variableSpace = graph->getVariableSpace(); + //graph->buildGraph(); + + auto clone = graph->clone(); + + Nd4jStatus statusOriginal = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, statusOriginal); + ASSERT_TRUE(variableSpace->hasVariable(3)); + + Nd4jStatus statusClone = GraphExecutioner::execute(clone); + + ASSERT_EQ(ND4J_STATUS_OK, statusClone); + + ASSERT_TRUE(variableSpace->hasVariable(3)); + + auto z0 = variableSpace->getVariable(3)->getNDArray(); + auto z1 = clone->getVariableSpace()->getVariable(3)->getNDArray(); + + ASSERT_TRUE(exp.isSameShape(z0)); + ASSERT_TRUE(exp.equalsTo(z0)); + + ASSERT_TRUE(exp.isSameShape(z1)); + ASSERT_TRUE(exp.equalsTo(z1)); + + delete graph; + delete clone; +} + + + + +TEST_F(GraphTests, Test_Clone_2) { + auto exp = NDArrayFactory::create('c', {3}); + exp.assign(3.0); + + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + auto variableSpace = graph->getVariableSpace(); + graph->buildGraph(); + + auto clone = graph->clone(); + + Nd4jStatus statusOriginal = GraphExecutioner::execute(graph); + + ASSERT_EQ(ND4J_STATUS_OK, statusOriginal); + ASSERT_TRUE(variableSpace->hasVariable(3)); + + Nd4jStatus statusClone = GraphExecutioner::execute(clone); + + ASSERT_EQ(ND4J_STATUS_OK, statusClone); + + ASSERT_TRUE(variableSpace->hasVariable(3)); + + auto z0 = variableSpace->getVariable(3)->getNDArray(); + auto z1 = clone->getVariableSpace()->getVariable(3)->getNDArray(); + + ASSERT_TRUE(exp.isSameShape(z0)); + ASSERT_TRUE(exp.equalsTo(z0)); + + ASSERT_TRUE(exp.isSameShape(z1)); + ASSERT_TRUE(exp.equalsTo(z1)); + + delete graph; + delete clone; +} + +TEST_F(GraphTests, Test_Dtype_Conversion_1) { + /*auto expD = NDArrayFactory::create('c', {3}, {3.0, 3.0, 3.0}); + auto expF = NDArrayFactory::create('c', {3}, {3.0, 3.0, 
3.0}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + graph->buildGraph(); + + + auto gd = graph->template asT(); + auto gf = gd->template asT(); + + // checking float graph + Nd4jStatus statusF = GraphExecutioner::execute(gf); + ASSERT_EQ(ND4J_STATUS_OK, statusF); + + ASSERT_TRUE(gf->getVariableSpace()->hasVariable(3)); + + ASSERT_TRUE(gf->getVariableSpace()->hasVariable(3)); + auto z1 = gf->getVariableSpace()->getVariable(3)->getNDArray(); + + ASSERT_TRUE(expF.isSameShape(z1)); + ASSERT_TRUE(expF.equalsTo(z1)); + + + // checking double graph + Nd4jStatus statusD = GraphExecutioner::execute(gd); + ASSERT_EQ(ND4J_STATUS_OK, statusD); + + ASSERT_TRUE(gd->getVariableSpace()->hasVariable(3)); + auto z2 = gd->getVariableSpace()->getVariable(3)->getNDArray(); + + ASSERT_TRUE(expD.isSameShape(z2)); + ASSERT_TRUE(expD.equalsTo(z2)); + + + delete graph; + delete gd; + delete gf; + */ +} + +TEST_F(GraphTests, Test_Dtype_Conversion_2) { + /* + NDArray expF('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); + NDArray expD('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + graph->buildGraph(); + + + auto gd = graph->template asT(); + auto gf = gd->template asT(); + + // checking float + auto resultF = GraphExecutioner::execute(gf); + ASSERT_EQ(ND4J_STATUS_OK, resultF); + ASSERT_TRUE(gf->getVariableSpace()->hasVariable(18)); + auto zF = 
gf->getVariableSpace()->getVariable(18)->getNDArray(); + + ASSERT_TRUE(expF.isSameShape(zF)); + ASSERT_TRUE(expF.equalsTo(zF)); + + + // checking double + auto resultD = GraphExecutioner::execute(gd); + ASSERT_EQ(ND4J_STATUS_OK, resultD); + ASSERT_TRUE(gd->getVariableSpace()->hasVariable(18)); + auto zD = gd->getVariableSpace()->getVariable(18)->getNDArray(); + + ASSERT_TRUE(expD.isSameShape(zD)); + ASSERT_TRUE(expD.equalsTo(zD)); + + delete graph; + delete gd; + delete gf; + */ +} + +TEST_F(GraphTests, Test_Hash_Function_1) { + /* + auto graph0 = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + auto graph1 = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + auto graph2 = GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); + + ASSERT_EQ(graph0->hashCode(), graph1->hashCode()); + ASSERT_NE(0L, graph1->hashCode()); + ASSERT_NE(graph0->hashCode(), graph2->hashCode()); + + auto graph0D = graph0->template asT(); + auto graph1D = graph1->template asT(); + + ASSERT_NE(graph0->hashCode(), graph0D->hashCode()); + ASSERT_EQ(graph0D->hashCode(), graph1D->hashCode()); + + delete graph0; + delete graph1; + delete graph2; + delete graph0D; + delete graph1D; + */ +} + +TEST_F(GraphTests, OpListTest_1) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); ; + + ASSERT_TRUE(graph != nullptr); + std::vector ops = graph->getOperations(); + + ASSERT_TRUE(ops.size() == 11); + GraphUtils::filterOperations(ops); + ASSERT_TRUE(ops.size() == 7); + + std::string exp(" -g \"-DSD_OPS_LIST='-DOP_rank=true -DOP_range=true -DOP_subtract=true -DOP_permute=true -DOP_matmul=true -DOP_biasadd=true -DOP_TRANSFORM{15}=true '\""); + std::string out = GraphUtils::makeCommandLine(ops); +// nd4j_printf("EXP: >%s<\n", exp.c_str()); +// nd4j_printf("OUT: >%s<\n", out.c_str()); + ASSERT_EQ(exp, out); + + delete graph; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(GraphTests, 
OpListTest_2) { + auto graph0 = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + auto graph1 = GraphExecutioner::importFromFlatBuffers("./resources/tensor_slice.fb"); + + ASSERT_TRUE(graph0 != nullptr); + ASSERT_TRUE(graph1 != nullptr); + + std::vector ops = graph0->getOperations(); + std::vector ops1 = graph1->getOperations(); + std::copy ( ops1.begin(), ops1.end(), std::back_inserter(ops)); + + ASSERT_TRUE(ops.size() == 13); + + GraphUtils::filterOperations(ops); + + std::string exp = " -g \"-DSD_OPS_LIST='-DOP_rank=true -DOP_range=true -DOP_subtract=true -DOP_permute=true -DOP_matmul=true -DOP_biasadd=true -DOP_TRANSFORM{15}=true -DOP_strided_slice=true -DOP_ACCUMULATION{1}=true '\""; + + ASSERT_TRUE(ops.size() == 9); + ASSERT_EQ(exp, GraphUtils::makeCommandLine(ops)); + + delete graph0; + delete graph1; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(GraphTests, OpListTest_3) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); ; + + ASSERT_TRUE(graph != nullptr); + std::vector ops = graph->getOperations(); + std::vector ops2(ops); + std::copy(ops.begin(), ops.end(), std::back_inserter(ops2)); + + ASSERT_TRUE(ops.size() == 11); + ASSERT_TRUE(ops2.size() == 2 * ops.size()); + + GraphUtils::filterOperations(ops2); + GraphUtils::filterOperations(ops); + ASSERT_TRUE(ops.size() == ops2.size()); + ASSERT_TRUE(ops.size() == 7); + ASSERT_TRUE(GraphUtils::makeCommandLine(ops) == GraphUtils::makeCommandLine(ops2)); + + delete graph; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(GraphTests, OpListTest_4) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/conv_0.fb"); ; + + ASSERT_TRUE(graph != nullptr); + std::vector ops = graph->getOperations(); + std::vector ops2(ops); + std::copy(ops.begin(), ops.end(), std::back_inserter(ops2)); + + // nd4j_printf("Total ops before %i\n", ops.size()); + ASSERT_TRUE(ops.size() 
== 6); + ASSERT_TRUE(ops2.size() == 2 * ops.size()); + + GraphUtils::filterOperations(ops2); + GraphUtils::filterOperations(ops); + ASSERT_TRUE(ops.size() == ops2.size()); + ASSERT_TRUE(ops.size() == 5); + ASSERT_TRUE(GraphUtils::makeCommandLine(ops) == GraphUtils::makeCommandLine(ops2)); + + delete graph; +} + + +TEST_F(GraphTests, Test_Inplace_Execution_1) { + auto exp = NDArrayFactory::create('c', {5, 4}, {0.32454616f, -0.06604697f, 0.22593613f, 0.43166467f, -0.18320604f, 0.00102305f, -0.06963076f, 0.25266643f, 0.07568010f, -0.03009197f, 0.07805517f, 0.33180334f, -0.06220427f, 0.07249600f, -0.06726961f, -0.22998397f, -0.06343779f, 0.07384885f, -0.06891008f, -0.23745790f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + // graph->printOut(); + graph->tagInplaceNodes(); + + ASSERT_FALSE(graph->nodeById(8)->isInplace()); + ASSERT_TRUE(graph->nodeById(9)->isInplace()); + ASSERT_TRUE(graph->nodeById(10)->isInplace()); + ASSERT_FALSE(graph->nodeById(11)->isInplace()); + ASSERT_FALSE(graph->nodeById(12)->isInplace()); + ASSERT_TRUE(graph->nodeById(17)->isInplace()); + ASSERT_TRUE(graph->nodeById(18)->isInplace()); + + auto status = GraphExecutioner::execute(graph, graph->getVariableSpace()); + ASSERT_EQ(Status::OK(), status); + + auto z = graph->getVariableSpace()->getVariable(18)->getNDArray(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + auto z_17 = graph->getVariableSpace()->getVariable(17)->getNDArray(); + ASSERT_TRUE(z_17 == z); + + delete graph; +} + +TEST_F(GraphTests, Test_Inplace_Execution_2) { + Graph graphA; + + auto x = NDArrayFactory::create_('c', {5, 5}); + x->assign(-5.0); + + graphA.getVariableSpace()->putVariable(-1, x); + + // abs, result is 5 + auto nodeA0 = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {}); + // 1-, result -4 + auto nodeA1 = new Node(OpType_TRANSFORM_SAME, 35, 2, {1}, {}); + + // graph should return 4: abs(-4) + auto nodeA2 = new Node(OpType_TRANSFORM_SAME, 0, 3, {2}, 
{}); + + // graph should return 1 - 4 = -3 + auto nodeA21 = new Node(OpType_TRANSFORM_SAME, 35, 5, {3}, {}); + + // 1 - -4 = 3 + auto nodeA3 = new Node(OpType_TRANSFORM_SAME, 35, 4, {2}, {}); + + // same abs = 3 + auto nodeA31 = new Node(OpType_TRANSFORM_SAME, 35, 6, {4}, {}); + + graphA.addNode(nodeA0); + graphA.addNode(nodeA1); + graphA.addNode(nodeA2); + graphA.addNode(nodeA3); + graphA.addNode(nodeA21); + graphA.addNode(nodeA31); + + graphA.buildGraph(); + graphA.tagInplaceNodes(); + + // nodes have 1 output + ASSERT_TRUE(graphA.nodeById(1)->isInplace()); + ASSERT_TRUE(graphA.nodeById(2)->isInplace()); + + // this 2 nodes share same input: node 2, so they can't be inplace + ASSERT_FALSE(graphA.nodeById(3)->isInplace()); + ASSERT_FALSE(graphA.nodeById(4)->isInplace()); + + // these 2 ops are standalone, so they can be run inplace + ASSERT_TRUE(graphA.nodeById(5)->isInplace()); + ASSERT_TRUE(graphA.nodeById(6)->isInplace()); +} +#endif + +TEST_F(GraphTests, Test_Inplace_Outputs_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {2, 3}); + + sd::ops::test_output_reshape op; + auto result = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(GraphTests, Test_Inplace_Outputs_2) { +#ifndef __APPLE_OS__ + // we dont want testing this on apple. 
due to try/catch + + auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {3, 3}); + + bool failed = false; + sd::ops::test_output_reshape op; + try { + op.execute({&x}, {&z}, {}, {}, {}); + + } catch (const std::runtime_error& e) { + failed = true; + } + + + ASSERT_TRUE(failed); +#endif +} + +/* +TEST_F(GraphTests, Test_Minifier_1) { + // run preprocessor to produce single header + // if all ok - return value is 0, if error - non-zero value will be returned + std::string input("../include/ops/ops.h"); //declarable/CustomOperations.h"); + + ASSERT_EQ(0, GraphUtils::runPreprocessor(input.c_str(), "libnd4j_mini.hpp")); + // remove file from filesystem +#ifdef __linux__ + ASSERT_EQ(0, unlink("libnd4j_mini.hpp")); +#endif +} +*/ + +TEST_F(GraphTests, Test_Minifier_2) { + + // run preprocessor to produce single header + // if all ok - return value is 0, if error - non-zero value will be returned + ASSERT_EQ(0, GraphUtils::runPreprocessor("../include/ops/specials.h", "libnd4j_mini2.hpp")); + // remove file from filesystem +#ifdef __linux__ + ASSERT_EQ(0, unlink("libnd4j_mini2.hpp")); +#endif +} + +TEST_F(GraphTests, Test_Minifier_3) { + + // run preprocessor to produce single header + // if all ok - return value is 0, if error - non-zero value will be returned +#ifdef __linux__ + ASSERT_EQ(0x100, GraphUtils::runPreprocessor("/include/ops/ops.h", "libnd4j_mini3.hpp")); +#endif + // remove file from filesystem + //ASSERT_EQ(0, unlink("libnd4j_mini3.hpp")); + +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HashUtilsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HashUtilsTests.cpp new file mode 100644 index 000000000..bbd7b9ed4 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HashUtilsTests.cpp @@ -0,0 +1,45 @@ +/* 
****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 02.09.17. +// + +#include "testlayers.h" +#include + +class HashUtilsTests : public testing::Test { + +}; + + +TEST_F(HashUtilsTests, TestEquality1) { + std::string str("Conv2D"); + + Nd4jLong hash1 = sd::ops::HashHelper::getInstance().getLongHash(str); + ASSERT_EQ(-1637140380760460323L, hash1); +} + + + +TEST_F(HashUtilsTests, TestEquality2) { + std::string str("switch"); + + Nd4jLong hash1 = sd::ops::HashHelper::getInstance().getLongHash(str); + ASSERT_EQ(-1988317239813741487L, hash1); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HelpersTests1.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HelpersTests1.cpp new file mode 100644 index 000000000..e8be972ee --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HelpersTests1.cpp @@ -0,0 +1,2347 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * 
* https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +using namespace sd; + +class HelpersTests1 : public testing::Test { +public: + + HelpersTests1() { + + std::cout<('c', {4}, {14,17,3,1}); +// auto exp = NDArrayFactory::create('c', {4,4}, {-0.629253, -0.764093, -0.13484, -0.0449467, -0.764093, 0.641653, -0.0632377, -0.0210792, -0.13484,-0.0632377, 0.98884,-0.00371987, -0.0449467,-0.0210792,-0.00371987, 0.99876}); + +// auto result = ops::helpers::Householder::evalHHmatrix(x); +// ASSERT_TRUE(result.isSameShape(&exp)); +// ASSERT_TRUE(result.equalsTo(&exp)); + +// } + +// /////////////////////////////////////////////////////////////////// +// TEST_F(HelpersTests1, evalHHmatrix_test2) { + +// #ifdef __CUDABLAS__ +// return; +// #endif +// auto x = NDArrayFactory::create('c', {3}, {14,-4,3}); +// auto exp = NDArrayFactory::create('c', {3,3}, {-0.941742, 0.269069,-0.201802, 0.269069, 0.962715,0.0279639, -0.201802,0.0279639, 0.979027}); + +// auto result = ops::helpers::Householder::evalHHmatrix(x); + +// ASSERT_TRUE(result.isSameShape(&exp)); +// ASSERT_TRUE(result.equalsTo(&exp)); + +// } + +///////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, evalHHmatrixData_test1) { + + 
auto x = NDArrayFactory::create('c', {4}, {14,17,3,1}); + auto tail = NDArrayFactory::create('c', {3}); + auto expTail = NDArrayFactory::create('c', {3}, {0.468984, 0.0827618, 0.0275873}); + const double normXExpected = -22.2486; + const double coeffExpected = 1.62925; + + double normX, coeff; + ops::helpers::Householder::evalHHmatrixData(x, tail, coeff, normX); + + ASSERT_NEAR(normX, normXExpected, 1e-5); + ASSERT_NEAR(coeff, coeffExpected, 1e-5); + ASSERT_TRUE(tail.isSameShapeStrict(expTail)); + ASSERT_TRUE(tail.equalsTo(&expTail)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, Householder_mulLeft_test1) { + + auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); + auto tail = NDArrayFactory::create('c', {1,3}, {0.5,0.5,0.5}); + auto exp = NDArrayFactory::create('c', {4,4}, {9.05,15.8,11.4, 0.8, 8.525, 2.4,15.7,17.9, 17.525,16.4, 3.7, 1.9, 4.525, 2.4, 0.7,14.9}); + + ops::helpers::Householder::mulLeft(x, tail, 0.1); + + ASSERT_TRUE(x.isSameShapeStrict(exp)); + ASSERT_TRUE(x.equalsTo(&exp)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, Householder_mulLeft_test2) { + + auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); + auto tail = NDArrayFactory::create('c', {3,1}, {0.5,0.5,0.5}); + auto exp = NDArrayFactory::create('c', {4,4}, {9.05,15.8,11.4, 0.8, 8.525, 2.4,15.7,17.9, 17.525,16.4, 3.7, 1.9, 4.525, 2.4, 0.7,14.9}); + + ops::helpers::Householder::mulLeft(x, tail, 0.1); + + ASSERT_TRUE(x.isSameShapeStrict(exp)); + ASSERT_TRUE(x.equalsTo(&exp)); + +} + +///////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, Householder_mulRight_test1) { + + auto x = NDArrayFactory::create('c', {4,4}, {12 ,19 ,14 ,3 ,10 ,4 ,17 ,19 ,19 ,18 ,5 ,3 ,6 ,4 ,2 ,16}); + auto tail = NDArrayFactory::create('c', {1,3}, {0.5,0.5,0.5}); + auto exp = 
NDArrayFactory::create('c', {4,4}, {9,17.5,12.5, 1.5, 7, 2.5,15.5, 17.5, 15.8,16.4, 3.4, 1.4, 4.3,3.15,1.15,15.15}); + + ops::helpers::Householder::mulRight(x, tail, 0.1); + + ASSERT_TRUE(x.isSameShapeStrict(exp)); + ASSERT_TRUE(x.equalsTo(&exp)); +} + +///////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, BiDiagonalizeUp_test1) { + + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6,13,11,7,6,3,7,4,7,6,6,7,10}); + auto hhMatrixExp = NDArrayFactory::create('c', {4,4}, {1.524000, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367,0, 0.229221,-0.272237,0.938237,0}); + auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.1756, 24.3869, 0, 0, 0,-8.61985,-3.89823, 0, 0, 0, 4.03047,4.13018, 0, 0, 0,1.21666}); + + ops::helpers::BiDiagonalUp object(matrix); + // object._HHmatrix.printBuffer(); + + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, BiDiagonalizeUp_test2) { + + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto hhMatrixExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821, 0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); + auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.2916,7.03123, 0, 0, 0, 16.145,-22.9275, 0, 0, 0, -9.9264,-11.5516, 0, 0, 0,-12.8554}); + + ops::helpers::BiDiagonalUp object(matrix); + + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); + 
ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, BiDiagonalizeUp_test3) { + + auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12, 0,-15,10,2}); + auto hhMatrixExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); + auto hhBidiagExp = NDArrayFactory::create('c', {4,4}, {-17.2916,7.03123, 0, 0, 0,16.3413,-20.7828, 0, 0, 0,-18.4892,4.13261, 0, 0, 0,-21.323}); + + ops::helpers::BiDiagonalUp object(matrix); + // object._HHmatrix.printBuffer(); + + ASSERT_TRUE(hhMatrixExp.isSameShapeStrict(object._HHmatrix)); + ASSERT_TRUE(hhMatrixExp.equalsTo(&object._HHmatrix)); + ASSERT_TRUE(hhBidiagExp.isSameShapeStrict(object._HHbidiag)); + ASSERT_TRUE(hhBidiagExp.equalsTo(&object._HHbidiag)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test1) { + + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto vectorsUseqExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821,0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); + auto vectorsVseqExp = NDArrayFactory::create('c', {5,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.66025, 1.66979,-0.444696, 0.114105,0.130601, 1.58392, 0, -0.22821,0.215638,0.0524781, 1.99303, 0.0760699,0.375605, 0.509835,0.0591568}); + auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, {1.52048,1.66025,1.58392,1.99303}); + auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.37012,1.66979,0}); + + ops::helpers::BiDiagonalUp object(matrix); + 
ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); + + ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); + ASSERT_TRUE(vSeq._shift == 1); + ASSERT_TRUE(uSeq._shift == 0); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test2) { + + auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12 ,0,-15,10,2}); + auto vectorsUseqExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); + auto vectorsVseqExp = NDArrayFactory::create('c', {6,4}, {1.52048, 1.37012, 0.636326, -0.23412, 0.494454, 1.65232, 1.59666,-0.502606, 0.114105, 0.129651, 1.35075, 0, -0.22821, 0.214071, 0.103749, 1.61136, 0.0760699, 0.372875, 0.389936, 0.2398, 0,0.0935171,-0.563777, 0.428587}); + auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, {1.52048,1.65232,1.35075,1.61136}); + auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.37012,1.59666,0}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); + + ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); + ASSERT_TRUE(vSeq._shift == 1); + ASSERT_TRUE(uSeq._shift == 0); + 
+} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test3) { + + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto vectorsUseqExp = NDArrayFactory::create('c', {4,4}, {1.524, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367, 0, 0.229221,-0.272237,0.938237, 0}); + auto vectorsVseqExp = NDArrayFactory::create('c', {4,4}, {1.524, 1.75682,0.233741,0.289458, 0.496646, 1.5655, 1.02929,0.971124, 0.114611,-0.451039, 1.06367, 0, 0.229221,-0.272237,0.938237, 0}); + auto coeffsUseqExp = NDArrayFactory::create('c', {4,1}, { 1.524, 1.5655,1.06367,0}); + auto coeffsVseqExp = NDArrayFactory::create('c', {3,1}, {1.75682,1.02929, 0}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + + ASSERT_TRUE(uSeq._vectors.isSameShapeStrict(vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.isSameShapeStrict(vectorsVseqExp)); + ASSERT_TRUE(uSeq._vectors.equalsTo(&vectorsUseqExp)); + ASSERT_TRUE(vSeq._vectors.equalsTo(&vectorsVseqExp)); + + ASSERT_TRUE(vSeq._diagSize == uSeq._diagSize - 1); + ASSERT_TRUE(vSeq._shift == 1); + ASSERT_TRUE(uSeq._shift == 0); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test4) { + + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto exp = NDArrayFactory::create('c', {4,4}, {2.49369, 2.62176, 5.88386, 7.69905, -16.0588,-18.7319,-9.15007,-12.6164, 4.7247, 3.46252, 1.02038, -1.4533, 2.9279,-2.29178, 1.90139,-0.66187}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix); + + ASSERT_TRUE(matrix.equalsTo(&exp)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, 
HHsequence_test5) { + + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto exp = NDArrayFactory::create('c', {5,4}, {4.52891, 8.09473,-2.73704,-13.0302, -11.0752, 7.41549,-3.75125,0.815252, -7.76818,-15.9102,-9.90869,-11.8677, 1.63942,-17.0312,-9.05102,-4.49088, -9.63311,0.540226,-1.52764, 5.79111}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix); + + ASSERT_TRUE(matrix.equalsTo(&exp)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test6) { + + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9,-1,3,9, -4.43019,-15.1713, -3.2854,-7.65743, -9.39162,-7.03599, 8.03827, 9.48453, -2.97785, -16.424, 5.35265,-20.1171, -0.0436177, -13.118,-8.37287,-17.3012, -1.14074, 4.18282,-10.0914,-5.69014}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test7) { + + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto exp = NDArrayFactory::create('c', {4,4}, {9,13,3,6,-5.90424,-2.30926,-0.447417, 3.05712, -10.504,-9.31339, -8.85493,-10.8886, -8.29494,-10.6737, -5.94895,-7.55591}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix); + + ASSERT_TRUE(matrix.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test8) { + + auto 
matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto exp = NDArrayFactory::create('c', {5,4}, {9, -13, 3, 6, 13, 11, 7, -6, -6.90831,-5.01113, 0.381677,0.440128, -0.80107,0.961605,-0.308019,-1.96153, -0.795985, 18.6538, 12.0731, 16.9988}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix); + + ASSERT_TRUE(matrix.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test9) { + + auto matrix = NDArrayFactory::create('c', {6,4}, {9,-13,3,6, 13,11,7,-6, 3,7,4,7, -6,6,7,10, 2,17,9,12 ,0,-15,10,2}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -13, 3, 6, 13, 11, 7, -6, 3, 7, 4, 7, 3.77597, 18.6226,-0.674868, 4.61365, 5.02738,-14.1486, -2.22877,-8.98245, -0.683766, 1.73722, 14.9859, 12.0843}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix); + + ASSERT_TRUE(matrix.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test10) { + + auto matrix = NDArrayFactory::create('c', {4,4}, {9,13,3,6, 13,11,7,6, 3,7,4,7, 6,6,7,10}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, 2.58863, 11.0295,-4.17483,-0.641012, -1.21892,-16.3151, 6.12049, -20.0239, -0.901799,-15.0389,-12.4944, -20.2394}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test11) { + + auto matrix = NDArrayFactory::create('c', {5,4}, {9,-13,3,6, 13,11,7,-6, 
3,7,4,7, -6,6,7,10, 2,17,9,12}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, 1.14934, 4.40257, 8.70127,-1.18824, 1.5132,0.220419,-11.6285,-11.7549, 2.32148, 24.3838,0.256531, 25.9116}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test12) { + + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, 6, 7, 19, -2.62252,-22.2914, 4.76743,-19.6689, -1.05943,-9.00514,-11.8013,-7.94571}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test13) { + + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto exp = NDArrayFactory::create('c', {6,4}, {9 , -1 , 3 , 9, -4.65167, 3.44652, 7.83593, 22.6899, -9.48514, -21.902, 5.66559,-13.0533, -0.343184, 15.2895, 7.2888, 14.0489, 0.289638,-1.87752, 3.944,-1.49707, -2.48845, 3.18285,-10.6685,0.406502}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test14) { + + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto exp = NDArrayFactory::create('c', {5,5}, {1.78958, 8.06962,-6.13687, 4.36267, 1.06472, -14.9578, -8.1522, 1.30442,-18.3343,-13.2578, 13.5536, 5.50764, 15.7859, 7.60831, 11.7871, -1.3626,-0.634986, 7.60934, -2.1841, 5.62694, -13.0577, 15.1554, -7.6511, 3.76365,-5.87368}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.mulLeft(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test15) { + + auto matrix = NDArrayFactory::create('c', {5,3}, {9,-13,3, 13,11,7, 3,7,4, -6,6,7, 2,17,9}); + auto matrix2 = NDArrayFactory::create('c',{5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto exp = NDArrayFactory::create('c', {5,5}, {9, -1, 3, 9, 10, 11, -7, -5, 3, 2, 4, 7, -1, 6, 7, -9.26566,-16.4298, 1.64125,-17.3243,-7.70257, -16.7077, 4.80216,-19.1652,-2.42279,-13.0258}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.mulLeft(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test16) { + + auto matrix = NDArrayFactory::create('c', {5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create('c',{5,5}, {-0.372742, 0.295145, 0.325359, 0.790947, 0.20615, -0.455573,-0.824221,-0.239444, 0.216163,-0.0951492, -0.165663, 0.285319, -0.18501, 0.130431, 
-0.916465, -0.7869, 0.245393, 0.116952,-0.541267, 0.117997, -0.0828315, 0.303191,-0.888202, 0.133021, 0.3076}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.applyTo(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test17) { + + auto matrix = NDArrayFactory::create('c', {5,5}, {9,-1,3,9,10, 11,-7,-5,3, 2, 4,7,-1,6,7, 19,2,17,9,15, 2,17,-9,15,2}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create('c',{5,5}, {1, 0, 0, 0, 0, 0,-0.022902, 0.986163, 0.0411914, 0.158935, 0, -0.44659, 0.021539, 0.797676,-0.404731, 0,-0.554556, 0.103511, -0.600701, -0.56649, 0,-0.701784,-0.127684,-0.0342758, 0.700015}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.applyTo(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHsequence_test18) { + + auto matrix = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create('c',{6,6}, {-0.637993, 0.190621,-0.524821,-0.312287, 0.407189, 0.133659, -0.708881, 0.0450803, 0.47462, 0.232701,-0.204602,-0.417348, -0.212664,-0.0405892,-0.297123,0.0240276,-0.821557, 0.435099, 0.0708881, -0.432466, -0.49252,-0.145004,-0.199312,-0.710367, -0.141776, -0.56468,-0.180549, 0.706094, 0.274317, 0.233707, -0.141776, -0.673865, 0.368567,-0.572848,0.0490246, 0.243733}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence uSeq = object.makeHHsequence('u'); + uSeq.applyTo(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// 
+TEST_F(HelpersTests1, HHsequence_test19) { + + auto matrix = NDArrayFactory::create('c',{6,4}, {9,-1,3,9, 10,11,-7,-5, 3,2,4,7, -1,6,7,19, 2,17,9,15, 2,17,-9,15}); + auto matrix2 = NDArrayFactory::create('c', {10,10}); + matrix2 = 100.; + auto exp = NDArrayFactory::create('c',{4,4}, {1, 0, 0, 0, 0,-0.859586, 0.28601, -0.42345, 0, 0.19328,-0.585133,-0.787567, 0,-0.473027,-0.758826, 0.447693}); + + ops::helpers::BiDiagonalUp object(matrix); + ops::helpers::HHsequence vSeq = object.makeHHsequence('v'); + vSeq.applyTo(matrix2); + + ASSERT_TRUE(matrix2.equalsTo(&exp)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHcolPivQR_1) { + + auto matrix1 = NDArrayFactory::create('c', {5,6}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + + auto expQR = NDArrayFactory::create('c', {5,6}, {-32.6649659, -4.9594419, -8.2657365, 7.2248659, 16.5927006, 11.7251002, -0.1354883, -29.0586293, 10.9775804, -14.6886248, 4.1884104, 20.7115773, 0.3483986, 0.3236753, 25.5376258, 1.6432380, 9.6395914, -9.0237996, -0.0580664, 0.0798999, -0.0799029, 19.5280665, -4.9773587, 16.0968604, 0.3483986, -0.6667832, 0.0252425, 0.0159188, 10.6978354, -4.6919842}); + auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.58166, 1.28555, 1.98605, 1.99949, 0}); + auto expPermut = NDArrayFactory::create('c', {6,6}, {0,1,0,0,0,0, 0,0,1,0,0,0, 1,0,0,0,0,0, 0,0,0,0,0,1, 0,0,0,0,1,0, 0,0,0,1,0,0}); + + ops::helpers::HHcolPivQR qr(matrix1); + + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHcolPivQR_2) { + + auto matrix1 = 
NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + + auto expQR = NDArrayFactory::create('c', {6,6}, {38.1707, -3.03898, 5.16103, 23.0805, -7.57126, -13.885, -0.41519, 34.3623, 3.77403, 2.62327, -8.17784, 9.10312, 0.394431, 0.509952,-30.2179, -6.78341, 12.8421, 28.5491, -0.290633, 0.111912,0.450367, 28.1139, 15.5195, 2.60562, 0.332152, 0.405161,0.308163,0.0468127, 22.294,-2.94931, 0.249114,0.0627956,0.657873, 0.76767,-0.752594,-7.46986}); + auto expCoeffs = NDArrayFactory::create('c', {1,6}, {1.26198, 1.38824, 1.15567, 1.25667, 1.27682, 0}); + auto expPermut = NDArrayFactory::create('c', {6,6}, {0,0,1,0,0,0, 0,0,0,0,1,0, 0,0,0,1,0,0, 0,1,0,0,0,0, 0,0,0,0,0,1, 1,0,0,0,0,0}); + + ops::helpers::HHcolPivQR qr(matrix1); + + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, HHcolPivQR_3) { + + NDArray matrix1('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + + auto expQR = NDArrayFactory::create('c', {6,5}, {-37.054 , 0.323852 , 8.04231 , -22.9395 ,-13.089, 0.105164, 32.6021, 6.42277, -0.262898,-1.58766, 0.140218, -0.485058, 29.2073, -9.92301,-23.7111, -0.262909,-0.00866538, 0.103467, 8.55831,-1.86455, -0.315491, 0.539207, 0.40754,-0.0374124,-7.10401, 0.315491, 0.385363,-0.216459, -0.340008,0.628595}); + auto expCoeffs = NDArrayFactory::create('c', {1,5}, {1.53975, 1.19431, 1.63446, 1.7905, 1.43356}); + auto expPermut = NDArrayFactory::create('c', {5,5}, {0,0,0,1,0, 1,0,0,0,0, 0,0,0,0,1, 0,0,1,0,0, 
0,1,0,0,0}); + + ops::helpers::HHcolPivQR qr(matrix1); + + ASSERT_TRUE(expQR.equalsTo(&qr._qr)); + ASSERT_TRUE(expCoeffs.equalsTo(&qr._coeffs)); + ASSERT_TRUE(expPermut.equalsTo(&qr._permut)); + + ASSERT_TRUE(expQR.isSameShapeStrict(qr._qr)); + ASSERT_TRUE(expCoeffs.isSameShapeStrict(qr._coeffs)); + ASSERT_TRUE(expPermut.isSameShapeStrict(qr._permut)); + +} + +#ifndef __CUDABLAS__ +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test1) { + + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto left = NDArrayFactory::create('c', {2,2}); + auto right = NDArrayFactory::create('c', {2,2}); + + auto expLeft = NDArrayFactory::create('c', {2,2}, {0.972022, 0.23489, -0.23489, 0.972022}); + auto expRight = NDArrayFactory::create('c', {2,2}, {0.827657, 0.561234, -0.561234, 0.827657}); + + ops::helpers::JacobiSVD::svd2x2(matrix3, 1, 3, left, right); + + ASSERT_TRUE(expLeft.equalsTo(&left)); + ASSERT_TRUE(expRight.equalsTo(&right)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test2) { + + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19}); + auto matrix5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + + auto exp3 = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -0.609208,19.6977, 8.63044,-11.9811,-4.67059, -2, -11, 8, 2, -6, 3.55371, 0,-12.5903, 7.51356, -5.5844, 16, 15, -3, 7, 0}); + auto exp4 = NDArrayFactory::create('c', {5,5}, {12, -10.9657,19,24.5714, -6, 3, -2.6399, 2,8.83351, -7, 14,-0.406138,18,18.7839, 18, -14, 
12.8949, 1,-7.9197, 2, -3, 23.353, 8, 8.2243,-19}); + auto exp5 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + + ops::helpers::JacobiSVD jac(matrix3, true, true, true); + jac._m = matrix3; + jac._u = matrix4; + jac._v = matrix5; + + double maxElem; + bool result = jac.isBlock2x2NotDiag(matrix3, 1, 3, maxElem); + + // ASSERT_NEAR(maxElem, 19.69772, 1e-5); + ASSERT_TRUE(exp3.equalsTo(&matrix3)); + ASSERT_TRUE(exp4.equalsTo(&jac._u)); + ASSERT_TRUE(exp5.equalsTo(&jac._v)); + + ASSERT_TRUE(result); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test3) { + + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, -1.14919,-12.1206,3.59677, 4.34919,-4.24758, -1.94919, 11.7427,11.6698,-10.4444,-2.74919, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::JacobiSVD::mulRotationOnLeft(1, 2, matrix, rotation); + + ASSERT_TRUE(expected.equalsTo(&matrix)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test4) { + + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 1.94919, 4.92056,-8.79677,1.25081, 5.04758, 1.14919,-16.1427,-8.46976,11.2444,0.349193, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::JacobiSVD::mulRotationOnLeft(2, 1, matrix, rotation); + + ASSERT_TRUE(expected.equalsTo(&matrix)); +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test5) { + + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, -18, -13, 14, 2, 1.14919,6.32056,-4.59677,-1.14919, 3.44758, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::JacobiSVD::mulRotationOnLeft(2, 2, matrix, rotation); + + ASSERT_TRUE(expected.equalsTo(&matrix)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test6) { + + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + + auto expected = NDArrayFactory::create('c', {5,5}, {-18,-14.5173, 4.5746,-7, 1, 2, 6.46976,-16.5427,14, 2, -2,-8.39677,-6.92056, 2,-6, -3,-7.79677,-4.59677,-2, 7, 16, 5.32379, 11.019, 7, 0}); + + ops::helpers::JacobiSVD::mulRotationOnRight(1, 2, matrix, rotation); + + ASSERT_TRUE(expected.equalsTo(&matrix)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test7) { + + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 14.9173, 3.0254,-7, 1, 2,-13.6698,11.3427,14, 2, -2, 3.99677,10.1206, 2,-6, -3, 4.59677,7.79677,-2, 7, 16, 0.67621,-12.219, 7, 0}); + + ops::helpers::JacobiSVD::mulRotationOnRight(2, 1, matrix, rotation); 
+ + ASSERT_TRUE(expected.equalsTo(&matrix)); +} + +////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test8) { + + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto rotation = NDArrayFactory::create('c', {2,2}, {0.2, math::nd4j_sqrt(0.6), -math::nd4j_sqrt(0.6), 0.2}); + + auto expected = NDArrayFactory::create('c', {5,5}, {-18, 1, 18.5173,-7, 1, 2,-18,-12.6698,14, 2, -2,-11, 7.79677, 2,-6, -3, -8, 7.79677,-2, 7, 16, 15,-2.92379, 7, 0}); + + ops::helpers::JacobiSVD::mulRotationOnRight(2, 2, matrix, rotation); + + ASSERT_TRUE(expected.equalsTo(&matrix)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test9) { + + auto matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); + + ops::helpers::JacobiSVD jac(matrix, true, true, true); + + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test10) { + + auto 
matrix = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expS = NDArrayFactory::create('c', {5,1}, {35.7975, 29.1924, 11.1935, 9.2846, 6.77071}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.744855,0.0686476, 0.079663,0.0889877, 0.65285, -0.386297,-0.760021,0.00624688, 0.156774, 0.498522, 0.186491,-0.322427, 0.773083,-0.468826,-0.209299, 0.246053,-0.215594, 0.240942, 0.821793,-0.399475, -0.447933, 0.516928, 0.581295, 0.269001, 0.349106}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.627363, 0.23317, 0.501211, 0.160272, -0.524545, -0.0849394, 0.917171,-0.155876,-0.0124053, 0.356555, 0.66983, 0.182569, 0.696897, 0.179807,0.000864568, -0.387647, -0.264316, 0.416597, 0.0941014, 0.772955, 0.0160818,-0.0351459,-0.255484, 0.965905, 0.0161524}); + + ops::helpers::JacobiSVD jac(matrix, true, true, false); + + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test11) { + + auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + + auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); + auto expU = NDArrayFactory::create('c', {6,5}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648, 0.120912, -0.32916,-0.0202265, 0.921633, -0.153994, 0.180033,-0.294831, 0.357867, -0.194106, -0.646595, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, 
-0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); + + ops::helpers::JacobiSVD jac(matrix, true, true, false); + + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test12) { + + auto matrix = NDArrayFactory::create('c', {6,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + + auto expS = NDArrayFactory::create('c', {5,1}, {36.27, 32.1997, 15.9624, 10.6407, 6.9747}); + auto expU = NDArrayFactory::create('c', {6,6}, {0.720125,-0.149734, 0.227784,-0.0288531, 0.595353,-0.227676, -0.509487,-0.567298, -0.237169,-0.0469077, 0.38648,-0.459108, 0.120912, -0.32916,-0.0202265, 0.921633,-0.153994,0.0591992, 0.180033,-0.294831, 0.357867, -0.194106,-0.646595,-0.544823, -0.354033, 0.521937, 0.556566, 0.305582, 0.211013,-0.393155, -0.222425,-0.433662, 0.673515, -0.128465, 0.099309, 0.531485}); + auto expV = NDArrayFactory::create('c', {5,5}, {-0.581609, 0.315327,0.333158, 0.34476, -0.576582, 0.117364, 0.889461,0.175174,-0.166603, 0.369651, 0.643246,-0.0899117,0.613288, 0.442462,-0.0790943, -0.480818, -0.264384,0.395122, 0.223126, 0.702145, -0.0548207, -0.177325,0.571031,-0.779632, -0.1779}); + + ops::helpers::JacobiSVD jac(matrix, true, true, true); + + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test13) { + + auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + + auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create('c', {5,5}, 
{0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create('c', {6,6}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, 0.53571, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079,-0.556052, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.431988, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339,-0.165176, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, 0.368038, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387, 0.233392}); + + ops::helpers::JacobiSVD jac(matrix, true, true, true); + + ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test14) { + + auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + + auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); + + ops::helpers::JacobiSVD jac(matrix, true, true, false); + + 
ASSERT_TRUE(expS.equalsTo(&jac._s)); + ASSERT_TRUE(expU.equalsTo(&jac._u)); + ASSERT_TRUE(expV.equalsTo(&jac._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test15) { + + auto matrix = NDArrayFactory::create('c', {5,6}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0, 3, -11, 2, 12, 10}); + + auto expS = NDArrayFactory::create('c', {5,1}, {40.499, 23.5079, 17.8139, 14.4484, 7.07957}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.592324,-0.121832,-0.484064,-0.624878,-0.0975619, 0.651331, 0.367418, 0.117429, 0.370792, 0.538048, -0.272693,-0.138725, 0.249336,-0.540587, 0.742962, 0.263619,-0.903996, 0.179714, 0.276206, 0.0686237, -0.284717,-0.117079,-0.810818, 0.321741, 0.379848}); + auto expV = NDArrayFactory::create('c', {6,5}, {-0.619634,-0.158345, 0.462262,-0.021009,-0.299779, -0.183441,-0.504296,-0.150804,-0.251078,-0.563079, 0.724925,-0.404744, 0.154104,-0.177039,-0.262604, 0.0335645,-0.501546, 0.221702, 0.797602, 0.186339, -0.0675636,0.0663677,-0.728788, 0.414614,-0.390566, -0.226262, -0.54849,-0.399426,-0.311613, 0.580387}); + + ops::helpers::JacobiSVD jac(matrix, false, false, false); + + ASSERT_TRUE(expS.equalsTo(&jac._s)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, JacobiSVD_test16) { + + NDArray rotation('c', {2,2}, sd::DataType::DOUBLE); + + NDArray exp1('c', {2,2}, {1,0,0,1 }, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {0,1,-1,0}, sd::DataType::DOUBLE); + NDArray exp3('c', {2,2}, {-1,0,0,-1}, sd::DataType::DOUBLE); + NDArray exp4('c', {2,2}, {0.983282, 0.182089, -0.182089, 0.983282}, sd::DataType::DOUBLE); + NDArray exp5('c', {2,2}, {0.249041, 0.968493, -0.968493, 0.249041}, sd::DataType::DOUBLE); + + ops::helpers::JacobiSVD::createJacobiRotationGivens(0, 0, rotation); + ASSERT_TRUE(rotation.equalsTo(exp1)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp1)); + + 
ops::helpers::JacobiSVD::createJacobiRotationGivens(0, -0.5, rotation); + ASSERT_TRUE(rotation.equalsTo(exp2)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp2)); + + ops::helpers::JacobiSVD::createJacobiRotationGivens(-0.5, 0, rotation); + ASSERT_TRUE(rotation.equalsTo(exp3)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp3)); + + + ops::helpers::JacobiSVD::createJacobiRotationGivens(2.7, -0.5, rotation); + ASSERT_TRUE(rotation.equalsTo(exp4)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp4)); + + ops::helpers::JacobiSVD::createJacobiRotationGivens(2.7, -10.5, rotation); + ASSERT_TRUE(rotation.equalsTo(exp5)); + ASSERT_TRUE(rotation.isSameShapeStrict(exp5)); +} + +TEST_F(HelpersTests1, test_binary_search_1) { + std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + + auto idx = sd::ops::helpers::binarySearch(array.data(), 2, 10); + ASSERT_EQ(2, idx); +} + +TEST_F(HelpersTests1, test_binary_search_2) { + std::array array = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + + auto idx = sd::ops::helpers::binarySearch(array.data(), 18, 10); + ASSERT_EQ(-1, idx); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test1) { + + auto matrix = NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); + auto matrix2 = NDArrayFactory::create('c', {5,5}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20 ,-4 ,-16 ,-9 ,-17 ,-5 ,-7 ,-19 ,-8 ,-9 ,9 ,6 ,14 ,-11}); + auto expM = NDArrayFactory::create('c', {5,5}, {-17,14,9,-12,-12, 5,-4, -19, -7,-12, 15,16,17.0294, -6, 8, -10,14, -15, 6,-10, -14,12, 0,-16, 0}); + auto expU = NDArrayFactory::create('c', {5,5}, {18,3, 2,7,-11, 7, 7.75131,10,-12.5665, -8, 13, 20.905,-4,-14.7979, -9, -17,-3.87565,-7,-19.2608, -8, -9, 9, 6, 14,-11}); + + ops::helpers::SVD svd(matrix, 4, true, true, true, 't'); + svd._m = matrix; + svd._u = matrix2; + svd.deflation1(1,1,2,2); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + 
ASSERT_TRUE(expU.equalsTo(&svd._u)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test2) { + + auto matrix= NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); + auto matrix2 = NDArrayFactory::create('c', {5,5}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20 ,-4 ,-16 ,-9 ,-17 ,-5 ,-7 ,-19 ,-8 ,-9 ,9 ,6 ,14 ,-11}); + auto expM = NDArrayFactory::create('c', {5,5}, {22.6716,14, 9,-12,-12, 5,-4,-19, -7,-12, 0,16, 0, -6, 8, -10,14,-15, 6,-10, -14,12, -1,-16, 3}); + auto expU = NDArrayFactory::create('c', {5,5}, {-12.1738, 3, -13.4089, 7,-11, 1.36735, 7, -12.1297,-13, -8, -12.3944,20, -5.60173,-16, -9, -17,-5,-7,-19, -8, -9, 9, 6, 14,-11}); + + ops::helpers::SVD svd(matrix, 4, true, true, true); + svd._m = matrix; + svd._u = matrix2; + svd.deflation1(0,0,2,2); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test3) { + + auto matrix= NDArrayFactory::create('c', {5,5}, {-17 ,14 ,9 ,-12 ,-12 ,5 ,-4 ,-19 ,-7 ,-12 ,15 ,16 ,17 ,-6 ,8 ,-10 ,14 ,-15 ,6 ,-10 ,-14 ,12 ,-1 ,-16 ,3}); + auto matrix2 = NDArrayFactory::create('c', {2,6}, {18 ,3 ,2 ,7 ,-11 ,7 ,7 ,10 ,-13 ,-8 ,13 ,20}); + auto expM = NDArrayFactory::create('c', {5,5}, {-17,14,9,-12,-12, 5,-4, -19, -7,-12, 15,16,17.0294, -6, 8, -10,14, -15, 6,-10, -14,12, 0,-16, 0}); + auto expU = NDArrayFactory::create('c', {2,6}, {18, 2.58377, 2, 7.16409,-11, 7, 7 ,10.4525 ,-13, -7.39897 ,13 ,20}); + + ops::helpers::SVD svd(matrix, 4, false, true, true, 't'); + svd._m = matrix; + svd._u = matrix2; + svd.deflation1(1,1,2,2); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test4) { + + auto matrix1 = NDArrayFactory::create('c', {6,5}, 
{12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto expM = NDArrayFactory::create('c', {6,5}, {12, 20, 19,-18, -6, 3, 6, 2, -7, -7, 14, 8, 18,-17, 18, -14,-15,8.06226, 2, 2, -3,-18, 0,-17, 2, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10,-16, -20, 13, 20,-10, -9, -1,-20.7138,4.46525, -4, 20, -11, 19,-18.4812,2.72876, 12,-19, 18,-18, 17, -10,-19, 14, -2, -7, -17, -14, -4,-16, 18, -6, -18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-18, -13, 14, 2, -2,-11,2.97683,-7.69015,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation2(1, 2, 2, 1, 1, 2, 1); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test5) { + + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto expM = NDArrayFactory::create('c', 
{6,5}, {18.4391, 20, 19,-18, -6, 3, 6, 2, -7, -7, 0, 8,18.4391,-17, 18, -14,-15, 1, 2, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10,-16,-20,13, 20,-10, -9,-15.8359, -7,-12.2566, -4, 20, -11,-1.30158, -5,-26.1401, 12,-19, 18,-19.3068, 17, 7.15871,-19, 14, -2, -7,-17, -14, -4,-16, 18, -6,-18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-1.08465,-13,22.7777, 2, -2,-5.64019, 8,9.65341,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation2(1, 0, 1, 1, 0, 2, 2); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test6) { + + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {2,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto expM = NDArrayFactory::create('c', {6,5}, {18.4391, 20, 19,-18, -6, 3, 6, 2, -7, -7, 0, 8,18.4391,-17, 18, -14,-15, 1, 2, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {2,6}, {-10, -0.542326,-20, 20.6084,20,-10, -9, -15.8359, -7,-12.2566,-4, 20}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-1.08465,-13,22.7777, 2, -2,-5.64019, 8,9.65341,-6, -3, -8, 8, -2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::SVD svd(matrix3, 4, false, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation2(1, 0, 1, 1, 0, 2, 2); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); 
+ ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test7) { + + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expM = NDArrayFactory::create('c', {6,5}, {12, 20, 19,-18, -6, 3, 6, 2, -7, -7, 14, 8,19.6977,-17, 18, -14,-15, 1, 2, 2, -3,-18, 0,-17, 0, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10, -16,-20, 13, 20,-10, -9,-9.03658, -7,-17.8701, -4, 20, -11, 10.0519, -5,-24.1652, 12,-19, 18, -20.51, 17,-1.82762,-19, 14, -2,-12.0826,-17,-9.95039, -4,-16, 18, -6,-18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19,-7, 1, 2,-18,-13,14, 2, -2,-11, 8, 2,-6, -3, -8, 8,-2, 7, 16, 15, -3, 7, 0}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation(1, 3, 1, 1, 2, 1); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test8) { + + auto matrix1 = NDArrayFactory::create('c', {6,5}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,-7 ,14 ,8 ,18 ,-17 ,18 ,-14 ,-15 ,1 ,2 ,2 ,-3 ,-18 ,8 ,-17 ,-19 ,12 ,18 ,6 ,-2 ,-17}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 
,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expM = NDArrayFactory::create('c', {6,5}, {12, 20,19,-18, -6, 3, 6, 2, -7, -7, 14,-15, 2,-17, 18, -14, 8, 1, 18, 2, -3,-18, 8,-17,-19, 12, 18, 6, -2,-17}); + auto expU = NDArrayFactory::create('c', {6,6}, {-10,-20,-16, 13, 20,-10, -9, -7, -1,-20, -4, 20, -11, -5, 19,-18, 12,-19, 18, 17,-18,-10,-19, 14, -2, -7,-17,-14, -4,-16, 18, -6,-18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19,-7, 1, 2,-18,-13, 2,14, -2,-11, 8,-6, 2, -3, -8, 8, 7,-2, 16, 15, -3, 7, 0}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + svd.deflation(0, 2, 2, 1, 2, 1); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test9) { + + auto col0 = NDArrayFactory::create('c', {10,1}, {12 ,20 ,19 ,-18 ,-6 ,3 ,6 ,2 ,-7 ,14}); + auto diag = NDArrayFactory::create('c', {10,1}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2}); + auto permut = NDArrayFactory::create('c', {1,10}, {8 ,1 ,4 ,0, 5 ,2 ,9 ,3 ,7 ,6}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expSingVals = NDArrayFactory::create('c', {10,1}, {-2, 15.304323, 11.2, -1, 1.73489, -12, -15.3043, -12.862, 5.6, 41.4039}); + auto expShifts = NDArrayFactory::create('c', {10,1}, {1, 19, 19, 1, 2, -18, -18, -13, 2, 2}); + auto expMus = NDArrayFactory::create('c', {10,1}, {-3, -3.695677, -7.8, -2, -0.265108, 6, 2.69568, 0.138048, 3.6, 39.4039}); + + auto singVals = NDArrayFactory::create('c', {10,1}); + auto shifts = NDArrayFactory::create('c', {10,1}); + auto mus = 
NDArrayFactory::create('c', {10,1}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd.calcSingVals(col0, diag, permut, singVals, shifts, mus); + + ASSERT_TRUE(expSingVals.equalsTo(&singVals)); + ASSERT_TRUE(expShifts.equalsTo(&shifts)); + ASSERT_TRUE(expMus.equalsTo(&mus)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test10) { + + auto singVals = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); + auto col0 = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); + auto diag = NDArrayFactory::create('c', {4,1}, {5 ,7 ,-13 ,14}); + auto permut = NDArrayFactory::create('c', {1,4}, {0 ,2 ,3 ,1 }); + auto mus = NDArrayFactory::create('c', {4,1}, {4,1,4,6}); + auto shifts = NDArrayFactory::create('c', {4,1}, {4,2,5,6}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expZhat = NDArrayFactory::create('c', {4,1}, {0, 0.278208, 72.501953, 0}); + + auto zhat = NDArrayFactory::create('c', {4,1}); + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd.perturb(col0, diag, permut, singVals, shifts, mus, zhat); + + ASSERT_NEAR(expZhat.e(1), zhat.e(1), EPS); + ASSERT_NEAR(expZhat.e(2), zhat.e(2), EPS); +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test11) { + + auto singVals = NDArrayFactory::create('c', {4,1}, {1 ,1 ,1 ,1}); + auto zhat = NDArrayFactory::create('c', {4,1}, {2 ,1 ,2 ,1}); + auto diag = NDArrayFactory::create('c', {4,1}, {5 ,7 ,-13 ,14}); + auto permut = NDArrayFactory::create('c', {1,4}, {0 ,2 ,3 ,1 }); + auto mus = NDArrayFactory::create('c', {4,1}, {4,1,4,6}); + auto shifts = NDArrayFactory::create('c', {4,1}, {4,2,5,6}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + + auto expU = NDArrayFactory::create('c', 
{5,5}, {-0.662161, 0.980399,-0.791469,-0.748434, 0, -0.744931, 0.183825,-0.593602,-0.392928, 0, 0.0472972, 0.061275,0.0719517, 0.104781, 0, 0.0662161,0.0356509, 0.126635, 0.523904, 0, 0, 0, 0, 0, 1}); + auto expV = NDArrayFactory::create('c', {4,4}, {-0.745259,-0.965209, -0.899497, -0.892319, -0.652102, 0.21114, -0.39353, -0.156156, -0.0768918,-0.130705,-0.0885868,-0.0773343, 0.115929,0.0818966, 0.167906, 0.416415}); + auto U = NDArrayFactory::create('c', {5,5}); + auto V = NDArrayFactory::create('c', {4,4}); + + + ops::helpers::SVD svd(matrix3, 4, true, true, true, 't'); + svd.calcSingVecs(zhat, diag,permut, singVals, shifts, mus, U, V); + + ASSERT_TRUE(expU.equalsTo(&U)); + ASSERT_TRUE(expV.equalsTo(&V)); + +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test12) { + + auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + + auto expSingVals = NDArrayFactory::create('c', {4,1}, {8.43282, 5, 2.3, 1.10167}); + auto expU = NDArrayFactory::create('c', {5,5}, {0.401972,0, 0.206791, 0.891995,0, 0,1, 0, 0,0, 0.816018,0,-0.522818,-0.246529,0, -0.415371,0,-0.826982, 0.378904,0, 0,0, 0, 0,1}); + auto expV = NDArrayFactory::create('c', {4,4}, {-0.951851,0,-0.133555,-0.275939, 0,1, 0, 0, 0.290301,0,-0.681937,-0.671333, -0.098513,0,-0.719114, 0.687873}); + + ops::helpers::SVD svd(matrix4, 4, true, true, true, 
't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = matrix3; + NDArray U, singVals, V; + svd.calcBlockSVD(1, 4, U, singVals, V); + + ASSERT_TRUE(expSingVals.equalsTo(&singVals)); + ASSERT_TRUE(expU.equalsTo(&U)); + ASSERT_TRUE(expV.equalsTo(&V)); + + ASSERT_TRUE(expSingVals.isSameShapeStrict(singVals)); + ASSERT_TRUE(expU.isSameShapeStrict(U)); + ASSERT_TRUE(expV.isSameShapeStrict(V)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test16) { + + auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + + auto expM = NDArrayFactory::create('c', {6,5}, {-2, -3, 2, 1, 0, 0,7.07022, 0, 0, 0, -4, 0,5.09585, 0, 0, -3, 0, 0,3.32256, 0, -5, 0, 0, 0,1.00244, -2, -3, -4, -5, 0}); + auto expU = NDArrayFactory::create('c', {6,6}, {-5.58884,-2.18397,-11.0944, 3.30292, 0,-10, 8.19094, 5.05917, 16.9641,-4.53112, 0, 20, 6.55878, 3.76734, 15.9255,-3.76399, 0,-19, 1.36021, 23.3551,-8.01165, -1.5816, 0, 14, -15.6318,-2.85386, 8.83051, 2.74286, 1,-16, 18, -6, -18, 1,-15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2, 14.5866, 3.90133, 1.06593, 9.99376, -2, 9.97311, 2.44445, 6.85159, 2.37014, -3, 0.56907,-8.93313,-5.31596, 3.10096, 16,-10.6859, 1.70708,-7.24295,-10.6975}); + + ops::helpers::SVD svd(matrix4, 4, true, true, true, 't'); + svd._m = matrix1; + svd._u = matrix2; + svd._v = 
matrix3; + + svd.DivideAndConquer(0, 3, 1, 1, 1); + // svd._m.printIndexedBuffer(); + ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); + ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); + ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, SVD_test17) { + + auto matrix1 = NDArrayFactory::create('c', {6,5}, {-2 ,-3 ,2 ,1 ,0 ,0 ,-4 ,5 ,-2 ,-3 ,-4 ,0 ,5 ,-1 ,-5 ,-3 ,-5 ,3 ,3 ,3 ,-5 ,5 ,-5 ,0 ,2 ,-2 ,-3 ,-4 ,-5 ,-3}); + auto matrix2 = NDArrayFactory::create('c', {6,6}, {-10 ,-16 ,-20 ,13 ,20 ,-10 ,-9 ,-1 ,-7 ,-20 ,-4 ,20 ,-11 ,19 ,-5 ,-18 ,12 ,-19 ,18 ,-18 ,17 ,-10 ,-19 ,14 ,-2 ,-7 ,-17 ,-14 ,-4 ,-16 ,18 ,-6 ,-18 ,1 ,-15 ,-12}); + auto matrix3 = NDArrayFactory::create('c', {5,5}, {-18 ,1 ,19 ,-7 ,1 ,2 ,-18 ,-13 ,14 ,2 ,-2 ,-11 ,8 ,2 ,-6 ,-3 ,-8 ,8 ,-2 ,7 ,16 ,15 ,-3 ,7 ,0}); + auto matrix4 = NDArrayFactory::create('c', {5,5}, {3 ,-8 ,5 ,7 ,-8 ,4 ,-19 ,-12 ,-4 ,-5 ,-11 ,19 ,-2 ,-7 ,1 ,16 ,-5 ,10 ,19 ,-19 ,0 ,-20 ,0 ,-8 ,-13}); + + auto expM = NDArrayFactory::create('c', {6,5}, {-2, -3, 2, 1, 0, 0,12.1676, 0, 0, 0, -4, 0,7.49514, 0, 0, -3, 0, 0,5.00951, 0, -5, 0, 0, 0, 1.63594, -2, 0, 0, 0, 0}); + auto expU = NDArrayFactory::create('c', {6,6}, {0.295543,-0.238695, 0.262095,-0.231772, -0.85631,-10, 0.519708,0.0571492,-0.368706,-0.727615, 0.247527, 20, 0.313717,-0.561567,-0.602941, 0.469567,-0.0468295,-19, 0.474589,-0.372165, 0.656962, 0.124776, 0.434845, 14, -0.564717,-0.697061,0.0150082, -0.4252, 0.119081,-16, 18, -6, -18, 1, -15,-12}); + auto expV = NDArrayFactory::create('c', {5,5}, {-18, 1, 19, -7, 1, 2,-0.0366659, 0.977361,-0.0316106,0.205967, -2, -0.670795, -0.151697, -0.503288,0.523185, -3, 0.740124,-0.0841435, -0.486714,0.456339, 16, 0.0300945, -0.121135, 0.71331,0.689645}); + + ops::helpers::SVD svd(matrix4, 10, true, true, true, 't'); + svd._m = matrix1; + svd._u = 
matrix2; + svd._v = matrix3; + + svd.DivideAndConquer(0, 3, 1, 1, 1); + + ASSERT_TRUE(expM.equalsTo(&svd._m)); + ASSERT_TRUE(expU.equalsTo(&svd._u)); + ASSERT_TRUE(expV.equalsTo(&svd._v)); + + ASSERT_TRUE(expM.isSameShapeStrict(svd._m)); + ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); + ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); +} + +// /////////////////////////////////////////////////////////////////// +// TEST_F(HelpersTests1, SVD_test18) { + +// auto matrix('c', {10,10}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 ,-18 ,20 ,14 , +// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 ,-6 ,-13 ,16 ,-18 ,-13 , +// -10 ,16 ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 ,-14 ,7 ,7 ,-9 , +// 5 ,-16 ,7 ,16 ,13 ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 ,-16 , +// -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , +// 9 ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 ,-2 ,17 ,-18 ,-5 ,-14 ,0 +// ,9 ,-16 ,9 ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4}); + +// auto expS('c', {10, 1}, {65.0394, 56.1583, 48.9987, 39.2841, 35.7296, 22.8439, 17.474, 15.2708, 15.0768, 0.846648}); + +// auto expU('c', {10,10}, {0.413187, 0.159572,0.0238453, 0.601154,-0.0428558, -0.461779, 0.41787, -0.221153, 0.0206268, 0.0532219, +// 0.364377,-0.154281, 0.199857,-0.0943331, 0.415653, -0.139834, -0.258458, 0.10677, 0.72003,-0.0749772, +// -0.315063,-0.418079,-0.377499, 0.37031, 0.0123835, 0.300036, 0.153702, -0.129223, 0.390675, 0.403962, +// 0.102001,-0.216667, -0.74093,-0.166164,-0.0269665, -0.240065, 0.0549761,-0.0178001, 0.0197525, -0.55134, +// -0.107298, 0.386899,-0.377536, 0.033214, 0.486739, -0.245438, -0.43788, -0.208875, -0.170449, 0.365491, +// 0.18026, 0.240482,-0.115801, 0.237399, -0.643413, 0.139274, -0.582963, -0.116222, 0.224524,-0.0525887, +// 0.141172, 0.340505,-0.261653, 0.186411, 0.0625811, 0.19585, 0.128195, 0.832893, 0.0319884, 0.0864513, +// -0.385777,-0.330504, 0.128342, 0.156083, -0.200883, -0.648548, -0.256507, 0.40519,-0.0434365, 0.0909978, +// 0.574478,-0.371028,-0.136672,-0.328417, -0.190226,-0.0476664,-0.0399815, 
0.0687528, -0.242039, 0.549918, +// 0.209886,-0.398294,0.0919207, 0.490454, 0.305228, 0.280486, -0.341358, 0.0540678, -0.432618, -0.264332}); + +// auto expV('c', {10,10}, {0.423823,-0.0845148, 0.389647, -0.10717,-0.168732, 0.123783, 0.159237, -0.450407, -0.611513,-0.0629076, +// 0.412121, 0.317493, -0.355665,-0.383203,-0.382616,-0.309073, -0.21869,-0.0746378, 0.0829771, 0.392186, +// -0.0603483, 0.232234, 0.0383737, 0.435441,0.0829318, 0.327822,-0.206101, 0.184083, -0.34018, 0.667018, +// -0.453935, 0.119616, 0.288392, 0.184366,-0.524289, -0.42264, 0.41005,-0.0505891,0.00333608, 0.195602, +// 0.247802, 0.0776165, 0.33026, 0.190986, 0.526809,-0.345006,0.0651023, -0.386472, 0.395169, 0.284091, +// 0.426355, -0.269507, 0.304685, 0.386708,-0.257916,-0.287742,-0.329622, 0.463719, 0.0613767, -0.16261, +// -0.384582, 0.241486, 0.425935,-0.292636,0.0465594,-0.125018,-0.685871, -0.112806,-0.0977978, -0.127356, +// -0.121678, -0.06796, -0.501443, 0.473165,0.0422977,-0.369324,-0.248758, -0.408769, -0.305785, -0.211138, +// 0.186099, 0.809997, 0.0338281, 0.268965, -0.04829, 0.141617, 0.12121, 0.0362537, 0.0831986, -0.436428, +// 0.0174496, 0.161638,-0.0334757,-0.224027, 0.439364,-0.478697, 0.237318, 0.457809, -0.483235,-0.0253522}); + +// ops::helpers::SVD svd(matrix, 8, true, true, true); +// // svd._u.printShapeInfo(); +// // svd._u.printIndexedBuffer(); + +// ASSERT_TRUE(expS.equalsTo(&svd._s)); +// ASSERT_TRUE(expU.equalsTo(&svd._u)); +// ASSERT_TRUE(expV.equalsTo(&svd._v)); + +// ASSERT_TRUE(expS.isSameShapeStrict(svd._s)); +// ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); +// ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); +// } + + +// /////////////////////////////////////////////////////////////////// +// TEST_F(HelpersTests1, SVD_test19) { + +// auto matrix('c', {11,10}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 ,-18 ,20 ,14 , +// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 ,-6 ,-13 ,16 ,-18 ,-13 , +// -10 ,16 ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 ,-14 ,7 ,7 ,-9 , +// 5 ,-16 
,7 ,16 ,13 ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 ,-16 , +// -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , +// 9 ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 ,-2 ,17 ,-18 ,-5 ,-14 ,0 +// ,9 ,-16 ,9 ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4, +// -7, 1, -2, 15, 0, 4, -9,19, -3, 10 }); + +// auto expS('c', {10, 1}, {65.5187, 56.305, 50.9808, 41.6565, 35.8698, 29.3898, 17.9743, 15.3568, 15.2223, 0.846847}); + +// auto expU('c', {11,11}, {-0.387999,-0.117659, 0.162976, 0.641067,-0.0174306, -0.181469,-0.218643, -0.308042, 0.0670776,-0.0632539, -0.462228, +// -0.37021, 0.14822, -0.195157,-0.0467394, -0.381275, -0.183363, 0.326599, -0.370579, -0.56626, 0.0798798, 0.225133, +// 0.339692, 0.433146, 0.30841, 0.134184, -0.108725, 0.466056,-0.153546, -0.359783, -0.189621, -0.402737, 0.0605675, +// -0.0650167, 0.268868, 0.662416, -0.327524, 0.0339198,-0.0916729,0.0415428, -0.0765093,-0.0288338, 0.546108, -0.247418, +// 0.114029,-0.361828, 0.379255,-0.0935836, -0.488912, -0.125232, 0.480666,-0.00544881, 0.280747, -0.36698,-0.0648559, +// -0.174798, -0.21859, 0.178313, 0.212153, 0.579101, 0.369942, 0.551063, -0.139813,-0.0296135, 0.0572204, 0.212783, +// -0.133981,-0.311817, 0.304673, 0.0865395, -0.104221, 0.196295,-0.191271, 0.571084, -0.603697,-0.0868996,-0.0196788, +// 0.398676, 0.319697, -0.112145, 0.235089, 0.201666, -0.337134, 0.43406, 0.261686, -0.283102,-0.0999458, -0.411893, +// -0.559998, 0.392802, 0.0996997, -0.281135, 0.24017, -0.136769,0.0121463, 0.218664, 0.127577, -0.550001,0.00227476, +// -0.197522, 0.403875,-0.0647804, 0.383315, -0.388502, 0.335719, 0.20912, 0.404926, 0.309087, 0.266437, 0.0942471, +// 0.140425,0.0934688, 0.325994, 0.345081, 0.0825574, -0.521239,-0.129018, 0.0806886, 0.0442647, 0.014397, 0.665103}); + +// auto expV('c', {10,10}, {-0.4428, 0.0661762,-0.361903, 0.0307317, 0.19574,-0.0356551,-0.241991, 0.0866805, 0.74701, 0.062837, +// -0.400091, -0.277277, 0.375095, -0.323052, 0.443668, -0.264809, 0.292881, -0.106586,-0.00623963,-0.392226, +// 0.0536693, 
-0.232105,0.0106246, 0.332557, -0.167406, 0.400872,0.0835708, 0.414598, 0.141906,-0.666936, +// 0.473793, -0.121962,-0.147941, 0.414665, 0.538964, -0.372149,-0.285458, -0.132952, -0.0166319,-0.195945, +// -0.251722,-0.0813691,-0.233887, 0.280439, -0.512597, -0.328782, 0.074277, -0.581806, -0.0327555,-0.284121, +// -0.406324, 0.284462,-0.168731, 0.518021, 0.226396, -0.109282, 0.381083, 0.305342, -0.359301, 0.162524, +// 0.335857, -0.302206,-0.484806, -0.196382,0.00286755, -0.111789, 0.672115, 0.0705632, 0.191787, 0.127533, +// 0.185896, 0.134279, 0.608397, 0.382412,-0.0997649, -0.117987, 0.326934,-0.0941208, 0.496913, 0.210914, +// -0.201675, -0.795446,0.0916484, 0.267237,0.00604554, 0.167517, -0.13914,-0.0355323, -0.0869256, 0.436465, +// 0.00123325, -0.142684,0.0978458,-0.0945446, -0.349755, -0.674457,-0.196126, 0.587134,-0.00964182,0.0249317}); + +// ops::helpers::SVD svd(matrix, 8, true, true, true); + +// ASSERT_TRUE(expS.equalsTo(&svd._s)); +// ASSERT_TRUE(expU.equalsTo(&svd._u)); +// ASSERT_TRUE(expV.equalsTo(&svd._v)); + +// ASSERT_TRUE(expS.isSameShapeStrict(svd._s)); +// ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); +// ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); +// } + + +// /////////////////////////////////////////////////////////////////// +// TEST_F(HelpersTests1, SVD_test20) { + +// auto matrix('c', {10,11}, {10 ,7 ,5 ,2 ,17 ,18 ,-18 ,10 ,18 ,1 ,4 ,2 ,-7 ,-18 ,20 ,14 , +// -3 ,-10 ,-4 ,2 ,-17 ,-17 ,1 ,2 ,-9 ,-6 ,-13 ,16 ,-18 ,-13 , +// -10 ,16 ,-10 ,-13 ,-11 ,-6 ,-19 ,17 ,-12 ,3 ,-14 ,7 ,7 ,-9 , +// 5 ,-16 ,7 ,16 ,13 ,12 ,2 ,18 ,6 ,3 ,-8 ,11 ,-1 ,5 ,16 ,-16 , +// -9 ,8 ,10 ,-7 ,-4 ,1 ,-10 ,0 ,20 ,7 ,-11 ,-13 ,-3 ,20 ,-6 , +// 9 ,10 ,8 ,-20 ,1 ,19 ,19 ,-12 ,-20 ,-2 ,17 ,-18 ,-5 ,-14 ,0 +// ,9 ,-16 ,9 ,-15 ,7 ,18 ,-10 ,8 ,-11 ,-4, +// -7, 1, -2, 15, 0, 4, -9,19, -3, 10 }); + +// auto expS('c', {10, 1}, {68.9437, 54.8773, 50.7858, 42.4898, 35.1984, 26.6285, 21.376, 12.2334, 5.9112, 0.38292}); + +// auto expU('c', {10,10}, {0.30332,-0.0677785, 0.155514, 
-0.722623,-0.0843687,-0.0712535, 0.414936, -0.15422, -0.381536,-0.057561, +// 0.473286, 0.0231518, 0.0878106, 0.45493, -0.311654, 0.138957, 0.311305, 0.509971, -0.288207,0.0656506, +// -0.131548, 0.32051, 0.489848,-0.0539042, -0.521328, -0.363728, -0.328685,-0.0329672,-0.0726502, 0.344431, +// 0.072974, 0.522632, -0.477056, 0.0618953,-0.0980883, -0.095653, -0.26596, -0.15453, -0.475107,-0.388594, +// 0.267569, -0.336154,-0.0930604, -0.261336, -0.39945, 0.480346, -0.568317, 0.0593335, 0.102036,-0.106029, +// -0.0919782, -0.460136, 0.106434, 0.327722, 0.0952523, 0.0915698, -0.129052, -0.460878, -0.59722, 0.240608, +// -0.248827, -0.48834, -0.243788, -0.106636,-0.0803772, -0.567457, -0.12005, 0.480504, -0.188409,-0.139802, +// 0.643408, -0.16245, -0.152596, 0.16849,-0.0120438, -0.51616,-0.0694232, -0.36172, 0.322169,0.0440701, +// -0.229467,-0.0227008, -0.588303,-0.0327104, -0.482264, 0.0794715, 0.340158, -0.175969, 0.108784, 0.449731, +// 0.229718, 0.169979, -0.227516, -0.21815, 0.454459, 0.017476, -0.278516, 0.287333, -0.148844, 0.655637}); + +// auto expV('c', {11,11}, {0.190806, -0.193628, 0.383793,-0.0266376, 0.113035, 0.158361, 0.0297803, -0.793229, -0.13761,-0.260666, -0.152503, +// -0.303449, 0.0392386, 0.250627, -0.165231, 0.141567, 0.0479565, 0.72763, 0.14053, -0.339907, 0.224366, -0.280806, +// -0.159724, -0.38984, -0.256355, -0.337861, 0.075089, -0.237427, -0.153718, -0.217747, 0.320899, 0.455058, -0.446697, +// 0.376823, -0.560303, 0.269135, 0.265416,-0.00742902, 0.0263377, -0.192808, 0.435842, -0.275365,0.0511804, -0.30799, +// 0.522537, 0.209791, -0.44191, -0.282323, -0.12139, 0.226382, 0.221075, 0.0844301, 0.0285412,-0.297578, -0.443394, +// 0.0588008, 0.115035, 0.54835, -0.52266, -0.141345, 0.411122, -0.182423, 0.213721, 0.353022, 0.119504, 0.0508673, +// -0.299021,-0.0424794, -0.285618, 0.177961, 0.35831, 0.769783, -0.215983,-0.00423939, -0.110575,0.0928082,-0.0841152, +// -0.0977062, -0.624782, -0.240391, -0.276154, -0.342018, 0.199695, 0.268881, 
0.00402219,-0.0536164, -0.17679, 0.450283, +// 0.428931, 0.0748696, -0.120853, -0.360103, 0.37093,-0.0611563, -0.100263, -0.0604207, -0.432926, 0.412875, 0.39142, +// -0.35553, 0.127463,-0.0199906, -0.343149, -0.315968, -0.115698, -0.442585, 0.0126156, -0.584161,-0.219242, -0.20156, +// -0.134753, -0.154272, 0.037343, -0.281348, 0.666324, -0.213813,-0.0427932, 0.238783, 0.132347,-0.557478, 0.0253325}); + +// ops::helpers::SVD svd(matrix, 8, true, true, true); + +// ASSERT_TRUE(expS.equalsTo(&svd._s)); +// ASSERT_TRUE(expU.equalsTo(&svd._u)); +// ASSERT_TRUE(expV.equalsTo(&svd._v)); + +// ASSERT_TRUE(expS.isSameShapeStrict(svd._s)); +// ASSERT_TRUE(expU.isSameShapeStrict(svd._u)); +// ASSERT_TRUE(expV.isSameShapeStrict(svd._v)); +// } + + +///////////////////////////////////////////////////////////////////// +//TEST_F(HelpersTests1, reverseArray_test1) { +// +// auto inArr = NDArrayFactory::create('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}); +// auto exp = NDArrayFactory::create('c', {2,5}, {10,9,8,7,6,5,4,3,2,1}); +// auto outArr = NDArrayFactory::create('c', {2,5}); +// +// +// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.shapeInfo(), outArr.getBuffer(), outArr.shapeInfo()); +// +// ASSERT_TRUE(outArr.equalsTo(&exp)); +// ASSERT_TRUE(outArr.isSameShapeStrict(exp)); +//} +// +// +///////////////////////////////////////////////////////////////////// +//TEST_F(HelpersTests1, reverseArray_test2) { +// +// auto inArr = NDArrayFactory::create('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}); +// auto exp = NDArrayFactory::create('c', {2,5}, {10,9,8,7,6,5,4,3,2,1}); +// +// +// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.shapeInfo(), inArr.getBuffer(), inArr.shapeInfo()); +// +// ASSERT_TRUE(inArr.equalsTo(&exp)); +// ASSERT_TRUE(inArr.isSameShapeStrict(exp)); +//} +// +// +///////////////////////////////////////////////////////////////////// +//TEST_F(HelpersTests1, reverseArray_test3) { +// +// auto 
inArr = NDArrayFactory::create('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}); +// auto exp = NDArrayFactory::create('c', {2,5}, {5,4,3,2,1,6,7,8,9,10}); +// auto outArr = NDArrayFactory::create('c', {2,5}); +// +// ops::helpers::reverseArray(sd::LaunchContext ::defaultContext(), inArr.getBuffer(), inArr.shapeInfo(), outArr.getBuffer(), outArr.shapeInfo(), 5); +// +// ASSERT_TRUE(outArr.equalsTo(&exp)); +// ASSERT_TRUE(outArr.isSameShapeStrict(exp)); +//} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, rnnCell_test1) { + + const int bS = 2; + const int inSize = 4; + const int numUnits = 4; + + NDArray xt('c', {bS, inSize}, sd::DataType::DOUBLE); + NDArray ht_1('c', {bS, numUnits}, sd::DataType::DOUBLE); + NDArray Wx('c', {inSize, numUnits}, sd::DataType::DOUBLE); + NDArray Wh('c', {numUnits, numUnits}, sd::DataType::DOUBLE); + NDArray b ('c', {2*numUnits}, {0.0,0.0,0.0,0.0, 0.1,0.2,0.3,0.4}); + NDArray ht('c', {bS, numUnits}, sd::DataType::DOUBLE); + + xt.assign(0.1); + ht_1.assign(0.2); + Wx.assign(0.3); + Wh.assign(0.4); + + NDArray expHt('c', {bS, numUnits}, {0.492988, 0.56489956, 0.6291452 , 0.6858091,0.492988, 0.56489956, 0.6291452 , 0.6858091}); + + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, rnnCell_test2) { + + const int bS = 2; + const int inSize = 10; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {bS, inSize}); + auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}, {0.0,0.0,0.0,0.0, 0.1,0.2,0.3,0.4}); + + auto ht = NDArrayFactory::create('c', {bS, numUnits}); + + xt.assign(0.1); + ht_1.assign(0.2); + 
Wx.assign(0.3); + Wh.assign(0.4); + + auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.6169093,0.67506987,0.72589741,0.76986654,0.6169093,0.67506987,0.72589741,0.76986654}); + + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, rnnCell_test3) { + + const int bS = 2; + const int inSize = 10; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {bS, inSize}); + auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}, {0.01,0.02,0.03,0.04, 0.05,0.06,0.07,0.08}); + + auto ht = NDArrayFactory::create('c', {bS, numUnits}); + + xt.assign(0.1); + ht_1.assign(0.2); + Wx.assign(0.3); + Wh.assign(0.4); + + auto expHt = NDArrayFactory::create('c', {bS, numUnits}, {0.5915195, 0.6043678, 0.6169093, 0.6291452,0.5915195, 0.6043678, 0.6169093, 0.6291452}); + + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, rnnCell_test4) { + + const int bS = 2; + const int inSize = 3; + const int numUnits = 4; + + auto xt = NDArrayFactory::create('c', {bS, inSize}); + auto ht_1 = NDArrayFactory::create('c', {bS, numUnits}); + auto Wx = NDArrayFactory::create('c', {inSize, numUnits}); + auto Wh = NDArrayFactory::create('c', {numUnits, numUnits}); + auto b = NDArrayFactory::create('c', {2*numUnits}); + + auto ht = NDArrayFactory::create('c', {bS, numUnits}); + + xt.linspace(0.01, 0.01); + ht_1 = 0.2; + Wx = 0.3; + Wh = 0.4; + b = 0.25; + + auto expHt = 
NDArrayFactory::create('c', {bS, numUnits}, {0.68474828, 0.68474828, 0.68474828, 0.68474828,0.69882484, 0.69882484, 0.69882484, 0.69882484}); + + ops::helpers::rnnCell(sd::LaunchContext ::defaultContext(), &xt, &Wx, &Wh, &b, &ht_1, &ht); + + ASSERT_TRUE(expHt.isSameShape(ht)); + ASSERT_TRUE(expHt.equalsTo(ht)); +} + +#endif +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulHelper_test_1) { + + auto x = NDArrayFactory::create('c', {3,3}, {10,11,12,13,14,15,16,17,18}); + auto y = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); + auto expected = NDArrayFactory::create('c', {3,3}, {138.,171.,204. ,174.,216.,258. ,210.,261.,312.}); + + auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete result; + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulHelper_test_2) { + + auto x = NDArrayFactory::create('c', {3,3}, {10,11,12,13,14,15,16,17,18}); + auto y = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); + auto expected = NDArrayFactory::create('c', {3,3}, {138.,171.,204. ,174.,216.,258. 
,210.,261.,312.}); + auto result = NDArrayFactory::create('c', {3,3}); + + MmulHelper::mmul(&x, &y, &result, 1., 0.); + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulHelper_test_3) { + + auto x = NDArrayFactory::create('c', {3,4}); x.linspace(1); + auto y = NDArrayFactory::create('c', {4,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {3,5}, {110.,120.,130.,140.,150.,246.,272.,298.,324.,350.,382.,424.,466.,508.,550.}); + + auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete result; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulHelper_test_4) { + + auto x = NDArrayFactory::create('c', {3,4}); x.linspace(1); + auto y = NDArrayFactory::create('c', {4,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {3,5}, {110.,120.,130.,140.,150.,246.,272.,298.,324.,350.,382.,424.,466.,508.,550.}); + auto result = NDArrayFactory::create('c', {3,5}); + + MmulHelper::mmul(&x, &y, &result, 1., 0.); + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulHelper_test_5) { + + auto x = NDArrayFactory::create('c', {4,3}); x.linspace(1); + auto y = NDArrayFactory::create('c', {3,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {4,5}, {46., 52., 58., 64., 70.,100.,115.,130.,145.,160.,154.,178.,202.,226.,250.,208.,241.,274.,307.,340.}); + + auto result = MmulHelper::mmul(&x, &y, nullptr, 1., 0.); + + ASSERT_TRUE(expected.isSameShape(result)); + ASSERT_TRUE(expected.equalsTo(result)); + + delete result; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, 
mmulHelper_test_6) { + + auto x = NDArrayFactory::create('c', {4,3}); x.linspace(1); + auto y = NDArrayFactory::create('c', {3,5}); y.linspace(1); + auto expected = NDArrayFactory::create('c', {4,5}, {46., 52., 58., 64., 70.,100.,115.,130.,145.,160.,154.,178.,202.,226.,250.,208.,241.,274.,307.,340.}); + auto result = NDArrayFactory::create('c', {4,5}); + + MmulHelper::mmul(&x, &y, &result, 1., 0.); + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulHelper_test_7) { + + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); + auto result = NDArrayFactory::create('c', {4,4}); + + MmulHelper::mmul(&x, &y, &result, 1., 0.); + + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, tensordot_test_1) { + + auto a = NDArrayFactory::create('c', {2, 3, 4}); + auto b = NDArrayFactory::create('c', {2, 5, 3}); + + auto c = MmulHelper::tensorDot(&a, &b, {1}, {2}); + + ASSERT_TRUE(c->isSameShape({2,4,2,5})); + delete c; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, tensordot_test_2) { + + auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); + auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); + + auto c = MmulHelper::tensorDot(&a, &b, {2,1}, {4,2}); + + ASSERT_TRUE(c->isSameShape({7,6,2,5,8})); + delete c; +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, tensordot_test_3) { + + auto a = NDArrayFactory::create('c', {7, 3, 4, 6}); + auto b = NDArrayFactory::create('c', {2, 5, 3, 8, 4}); + auto c = NDArrayFactory::create('f', {7,6,2,8,5}); + + 
MmulHelper::tensorDot(&a, &b, &c, {2,1}, {4,2}, {0,1,2,4,3}); + + ASSERT_TRUE(c.isSameShape({7,6,2,8,5})); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, tensordot_test_4) { + + auto a = NDArrayFactory::create('c', {7, 3, 4, 3}); + auto b = NDArrayFactory::create('c', {2, 5, 3, 2, 4}); + auto c = NDArrayFactory::create('f', {7,3,2,2,5}); + auto expected = NDArrayFactory::create('c', {7,3,2,2,5}, { 754.5, 2014.5, 3274.5, 4534.5 , 5794.5, 964.5, 2224.5, 3484.5, 4744.5, 6004.5, 7054.5, 8314.5, 9574.5, 10834.5, 12094.5, 7264.5, 8524.5, 9784.5, 11044.5, 12304.5, 786. , 2118. , 3450. , 4782. , 6114. , 1008. , 2340. , 3672. , 5004. , 6336. , + 7446. , 8778. , 10110. , 11442. , 12774. , 7668. , 9000. , 10332. , 11664. , 12996. , 817.5, 2221.5, 3625.5, 5029.5, 6433.5, 1051.5, 2455.5, 3859.5, 5263.5, 6667.5, 7837.5, 9241.5, 10645.5, 12049.5, 13453.5, 8071.5, 9475.5, 10879.5, 12283.5, 13687.5, + 1888.5, 5740.5, 9592.5, 13444.5, 17296.5, 2530.5, 6382.5, 10234.5, 14086.5, 17938.5,21148.5, 25000.5, 28852.5, 32704.5, 36556.5,21790.5, 25642.5, 29494.5, 33346.5, 37198.5, 1920. , 5844. , 9768. , 13692. , 17616. , 2574. , 6498. , 10422. , 14346. , 18270. , + 21540. , 25464. , 29388. , 33312. , 37236. ,22194. , 26118. , 30042. , 33966. , 37890. , 1951.5, 5947.5, 9943.5, 13939.5, 17935.5, 2617.5, 6613.5, 10609.5, 14605.5, 18601.5,21931.5, 25927.5, 29923.5, 33919.5, 37915.5,22597.5, 26593.5, 30589.5, 34585.5, 38581.5, + 3022.5, 9466.5, 15910.5, 22354.5, 28798.5, 4096.5, 10540.5, 16984.5, 23428.5, 29872.5,35242.5, 41686.5, 48130.5, 54574.5, 61018.5,36316.5, 42760.5, 49204.5, 55648.5, 62092.5, 3054. , 9570. , 16086. , 22602. , 29118. , 4140. , 10656. , 17172. , 23688. , 30204. , + 35634. , 42150. , 48666. , 55182. , 61698. ,36720. , 43236. , 49752. , 56268. , 62784. 
, 3085.5, 9673.5, 16261.5, 22849.5, 29437.5, 4183.5, 10771.5, 17359.5, 23947.5, 30535.5,36025.5, 42613.5, 49201.5, 55789.5, 62377.5,37123.5, 43711.5, 50299.5, 56887.5, 63475.5, + 4156.5, 13192.5, 22228.5, 31264.5, 40300.5, 5662.5, 14698.5, 23734.5, 32770.5, 41806.5,49336.5, 58372.5, 67408.5, 76444.5, 85480.5,50842.5, 59878.5, 68914.5, 77950.5, 86986.5, 4188. , 13296. , 22404. , 31512. , 40620. , 5706. , 14814. , 23922. , 33030. , 42138. , + 49728. , 58836. , 67944. , 77052. , 86160. ,51246. , 60354. , 69462. , 78570. , 87678. , 4219.5, 13399.5, 22579.5, 31759.5, 40939.5, 5749.5, 14929.5, 24109.5, 33289.5, 42469.5,50119.5, 59299.5, 68479.5, 77659.5, 86839.5,51649.5, 60829.5, 70009.5, 79189.5, 88369.5, + 5290.5, 16918.5, 28546.5, 40174.5, 51802.5, 7228.5, 18856.5, 30484.5, 42112.5, 53740.5,63430.5, 75058.5, 86686.5, 98314.5,109942.5,65368.5, 76996.5, 88624.5,100252.5,111880.5, 5322. , 17022. , 28722. , 40422. , 52122. , 7272. , 18972. , 30672. , 42372. , 54072. , + 63822. , 75522. , 87222. , 98922. ,110622. ,65772. , 77472. , 89172. ,100872. ,112572. , 5353.5, 17125.5, 28897.5, 40669.5, 52441.5, 7315.5, 19087.5, 30859.5, 42631.5, 54403.5,64213.5, 75985.5, 87757.5, 99529.5,111301.5,66175.5, 77947.5, 89719.5,101491.5,113263.5, + 6424.5, 20644.5, 34864.5, 49084.5, 63304.5, 8794.5, 23014.5, 37234.5, 51454.5, 65674.5,77524.5, 91744.5,105964.5,120184.5,134404.5,79894.5, 94114.5,108334.5,122554.5,136774.5, 6456. , 20748. , 35040. , 49332. , 63624. , 8838. , 23130. , 37422. , 51714. , 66006. , + 77916. , 92208. ,106500. ,120792. ,135084. ,80298. , 94590. ,108882. ,123174. ,137466. , 6487.5, 20851.5, 35215.5, 49579.5, 63943.5, 8881.5, 23245.5, 37609.5, 51973.5, 66337.5,78307.5, 92671.5,107035.5,121399.5,135763.5,80701.5, 95065.5,109429.5,123793.5,138157.5, + 7558.5, 24370.5, 41182.5, 57994.5, 74806.5,10360.5, 27172.5, 43984.5, 60796.5, 77608.5,91618.5,108430.5,125242.5,142054.5,158866.5,94420.5,111232.5,128044.5,144856.5,161668.5, 7590. , 24474. , 41358. , 58242. , 75126. 
,10404. , 27288. , 44172. , 61056. , 77940. , + 92010. ,108894. ,125778. ,142662. ,159546. ,94824. ,111708. ,128592. ,145476. ,162360. , 7621.5, 24577.5, 41533.5, 58489.5, 75445.5,10447.5, 27403.5, 44359.5, 61315.5, 78271.5,92401.5,109357.5,126313.5,143269.5,160225.5,95227.5,112183.5,129139.5,146095.5,163051.5}); + + a.linspace(0.5, 0.5); + b.linspace(0.5, 0.5); + + MmulHelper::tensorDot(&a, &b, &c, {2,1}, {4,2}, {0,1,2,4,3}); + + ASSERT_TRUE(c.isSameShape(expected)); + ASSERT_TRUE(c.equalsTo(expected)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, tensordot_test_5) { + + auto a = NDArrayFactory::create('c', {2, 3}); + auto b = NDArrayFactory::create('c', {3, 4}); + auto c = NDArrayFactory::create('f', {2, 4}); + auto expected = NDArrayFactory::create('c', {2, 4}, {9.5,11.,12.5 ,14.,20.75 ,24.5,28.25,32.}); + + a.linspace(0.5, 0.5); + b.linspace(0.5, 0.5); + + MmulHelper::tensorDot(&a, &b, &c, {1}, {0}); + // c.printIndexedBuffer(); + + ASSERT_TRUE(c.isSameShape(expected)); + ASSERT_TRUE(c.equalsTo(expected)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, tensordot_test_6) { + + int bS=2, iH=3,iW=2, iC=2,mC=2, kH=2,kW=2; + int oC=iC*mC; + int oH=3,oW=2; + + auto a = NDArrayFactory::create('c', {bS, iC, kH, kW, oH, oW}); + auto b = NDArrayFactory::create('c', {kH, kW, iC, mC}); + auto c = NDArrayFactory::create('c', {bS, oH, oW, iC*mC}); + auto expected = NDArrayFactory::create('c', {bS, oH, oW, iC*mC}, {100.,110.,336.,370.,107.,118.,345.,380.,114.,126.,354.,390.,121.,134.,363.,400.,128.,142.,372.,410.,135.,150.,381.,420., + 436.,494.,768.,850.,443.,502.,777.,860.,450.,510.,786.,870.,457.,518.,795.,880.,464.,526.,804.,890.,471.,534.,813.,900.}); + + a.linspace(0.5, 0.5); + b.linspace(0.5, 0.5); + + auto cR = c.reshape(a.ordering(), {bS, oH, oW, iC, mC}); + + // [iC, bS*oH*oW, kW*kH] x [iC, kH*kW, mC] = [iC, bS*oH*oW, mC] + MmulHelper::tensorDot(&a, &b, &cR, 
{{1,0,4,5,2,3}, {iC,bS*oH*oW,kW*kH}}, {{2,0,1,3},{iC,kH*kW,mC}}, {{3,0,1,2,4},{iC, bS*oH*oW, mC}}); + // c.printBuffer(); + + ASSERT_TRUE(c.isSameShape(expected)); + ASSERT_TRUE(c.equalsTo(expected)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmmulHelperAgain) { + auto x = NDArrayFactory::create('c', {128, 156}); + auto y = NDArrayFactory::create('c', {156, 256}); + auto z = NDArrayFactory::create('c', {128, 256}); + auto e = NDArrayFactory::create('c', {128, 256}); + + x.assign(1.0f); + y.assign(1.0f); + e.assign(156.0f); + + MmulHelper::mmul(&x, &y, &z); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, OpArgsHolder_test1) { + + auto x1 = NDArrayFactory::create('c', {1, 1}); + auto x2 = NDArrayFactory::create('c', {2, 2}); + auto x3 = NDArrayFactory::create('c', {3, 3}); + + OpArgsHolder holder1({&x1}); + OpArgsHolder holder2({&x1,&x2,&x3}, {4.f, 5.f}, {6}); + + ASSERT_TRUE(holder1.getNumInArrs() == 1); + ASSERT_TRUE(holder1.getNumTArgs() == 0); + ASSERT_TRUE(holder1.getNumIArgs() == 0); + + ASSERT_TRUE(holder2.getNumInArrs() == 3); + ASSERT_TRUE(holder2.getNumTArgs() == 2); + ASSERT_TRUE(holder2.getNumIArgs() == 1); + + const std::vector& isArrAlloc1 = holder1.getAllocInfo(); + ASSERT_TRUE(isArrAlloc1.size() == 0); + + const std::vector& isArrAlloc2 = holder2.getAllocInfo(); + ASSERT_TRUE(isArrAlloc2.size() == 0); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, OpArgsHolder_test2) { + + auto x1 = NDArrayFactory::create('c', {1, 1}); + auto x2 = NDArrayFactory::create('c', {2, 2}); + auto x3 = NDArrayFactory::create('c', {3, 3}); + auto grad = NDArrayFactory::create('c', {2, 3}); + + OpArgsHolder holderFF({&x1,&x2,&x3}, {4.f, 5.f}, {6}); + OpArgsHolder holderBP1 = holderFF.createArgsHolderForBP({&grad}); + OpArgsHolder holderBP2 = 
holderFF.createArgsHolderForBP({&grad}, true); + + ASSERT_TRUE(holderBP1.getNumInArrs() == 4); + ASSERT_TRUE(holderBP1.getNumTArgs() == 2); + ASSERT_TRUE(holderBP1.getNumIArgs() == 1); + ASSERT_TRUE(holderBP2.getNumInArrs() == 4); + ASSERT_TRUE(holderBP2.getNumTArgs() == 2); + ASSERT_TRUE(holderBP2.getNumIArgs() == 1); + + const std::vector& isArrAllocBP1 = holderBP1.getAllocInfo(); + ASSERT_TRUE(isArrAllocBP1.size() == 0); + + const std::vector& isArrAllocBP2 = holderBP2.getAllocInfo(); + for(int i = 0; i < holderFF.getNumInArrs(); ++i) { + ASSERT_TRUE(static_cast(isArrAllocBP2[i]) == true); + } + + ASSERT_TRUE(static_cast(isArrAllocBP2[holderFF.getNumInArrs()+1]) == false); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, OpArgsHolder_test3) { + + auto input = NDArrayFactory::create('c', {2, 3}, {1.,2.,3.,4.,5.,6.}); + auto gradO = NDArrayFactory::create('c', {4, 9}); + auto exp = NDArrayFactory::create('c', {4, 9}, {1, 2, 3, 1, 2, 3, 1, 2, 3,4, 5, 6, 4, 5, 6, 4, 5, 6,1, 2, 3, 1, 2, 3, 1, 2, 3,4, 5, 6, 4, 5, 6, 4, 5, 6}); + auto gradIExp = NDArrayFactory::create('c', {2, 3}, {0.78, 0.84, 0.9,1.32, 1.38, 1.44}); + + gradO.linspace(0.01, 0.01); + + OpArgsHolder holderFF({&input}, {}, {2, 3}); + sd::ops::tile opFF; // the kind of op doesn't matter, we simply check here whether op.execute() works with OpArgsHolder correctly + auto results = opFF.execute(holderFF); + auto tiled = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(exp.isSameShape(tiled)); + ASSERT_TRUE(exp.equalsTo(tiled)); + + OpArgsHolder holderBP = holderFF.createArgsHolderForBP({&gradO}, true); + sd::ops::tile_bp opBP; + results = opBP.execute(holderBP); + auto gradI = results.at(0); + ASSERT_EQ(Status::OK(), results.status()); + ASSERT_TRUE(gradIExp.isSameShape(gradI)); + ASSERT_TRUE(gradIExp.equalsTo(gradI)); + +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, 
checkGrad_test1) { + + auto x = NDArrayFactory::create('c', {2, 3}, {0.1, 0.2, 0.3, 0.4, 0.5 ,0.6}); + auto gradO = NDArrayFactory::create('c', {2, 3}); + + const OpArgsHolder argsHolderFF({&x}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &gradO}, {}, {}); + + sd::ops::sigmoid opFF; + sd::ops::sigmoid_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, checkGrad_test2) { + + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + + x.linspace(1); + weights.linspace(0.1, 0.1); + weights.permutei({2,3,1,0}); + + const OpArgsHolder argsHolderFF({&x, &weights}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, checkGrad_test3) { + + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2,3,1,0}); + + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, checkGrad_test4) { + + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2,3,1,0}); + + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}); + + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, checkGrad_test5) { + + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + + x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2,3,1,0}); + + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 1, 1}, {0.5, 1}); + + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, checkGrad_test6) { + + auto x = NDArrayFactory::create('c', {1, 1, 3, 3}); + auto weights = NDArrayFactory::create('c', {2, 1, 2, 2}); + auto bias = NDArrayFactory::create('c', {2, 1}); + auto gradO = NDArrayFactory::create('c', {1, 2, 3, 3}); + + 
x.linspace(1); + weights.linspace(0.1, 0.1); + bias = 0.5; + weights.permutei({2,3,1,0}); + + const OpArgsHolder argsHolderFF({&x, &weights, &bias}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + const OpArgsHolder argsHolderBP({&x, &weights, &bias, &gradO}, {}, {2, 2, 1, 1, 0, 0, 1, 1, 1}); + + sd::ops::conv2d opFF; + sd::ops::conv2d_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {1, 0, 1}, {0.5, 1}, GradCheck::MEAN); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, softMaxForVector_test1) { + + auto input = NDArrayFactory::create('c', {1,5}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {1,5}); + auto expOutput = NDArrayFactory::create('c', {1,5}); + expOutput = 1; + + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, softMaxForVector_test2) { + + auto input = NDArrayFactory::create('c', {5,1}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {5,1}); + auto expOutput = NDArrayFactory::create('c', {5,1}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); + + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, softMaxForVector_test3) { + + auto input= NDArrayFactory::create('c', {5}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {5}); + auto expOutput = NDArrayFactory::create('c', {5}, {0.01165623, 0.03168492, 0.08612854, 0.23412166, 0.63640865}); + + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); +} + +////////////////////////////////////////////////////////////////////// 
+TEST_F(HelpersTests1, softMaxForVector_test4) { + + NDArray input('c', {1500}, sd::DataType::DOUBLE); + NDArray output('c', {1500}, sd::DataType::DOUBLE); + NDArray expOutput('c', {1500}, {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, +0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.00001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 
0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, +0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001,0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, +0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000001, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002,0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, +0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000002, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003,0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000003, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, +0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000004, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 
0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005, 0.000005,0.000005, 0.000005, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000006, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000007, +0.000007, 0.000007, 0.000007, 0.000007, 0.000007, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000008, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009, 0.000009,0.000009, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000010, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000011, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, 0.000012, +0.000012, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000013, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000014, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000015, 0.000016, 0.000016, 0.000016, 0.000016, 0.000016, 0.000016,0.000017, 0.000017, 0.000017, 0.000017, 0.000017, 0.000017, 0.000018, 0.000018, 0.000018, 0.000018, 0.000018, 0.000018, 0.000019, 0.000019, 0.000019, 0.000019, 0.000019, 0.000020, 0.000020, 0.000020, 0.000020, 0.000020, 0.000021, 0.000021, 0.000021, 0.000021, 0.000021, 0.000022, +0.000022, 0.000022, 0.000022, 0.000023, 0.000023, 0.000023, 0.000023, 0.000023, 0.000024, 0.000024, 0.000024, 0.000024, 0.000025, 0.000025, 0.000025, 0.000025, 0.000026, 0.000026, 0.000026, 0.000026, 0.000027, 0.000027, 0.000027, 0.000028, 0.000028, 0.000028, 0.000028, 0.000029,0.000029, 0.000029, 0.000030, 0.000030, 0.000030, 0.000030, 0.000031, 0.000031, 0.000031, 0.000032, 0.000032, 0.000032, 0.000033, 0.000033, 0.000033, 0.000034, 0.000034, 0.000034, 0.000035, 0.000035, 0.000035, 
0.000036, 0.000036, 0.000036, 0.000037, 0.000037, 0.000038, 0.000038, +0.000038, 0.000039, 0.000039, 0.000039, 0.000040, 0.000040, 0.000041, 0.000041, 0.000041, 0.000042, 0.000042, 0.000043, 0.000043, 0.000044, 0.000044, 0.000044, 0.000045, 0.000045, 0.000046, 0.000046, 0.000047, 0.000047, 0.000048, 0.000048, 0.000049, 0.000049, 0.000050, 0.000050,0.000051, 0.000051, 0.000052, 0.000052, 0.000053, 0.000053, 0.000054, 0.000054, 0.000055, 0.000055, 0.000056, 0.000057, 0.000057, 0.000058, 0.000058, 0.000059, 0.000059, 0.000060, 0.000061, 0.000061, 0.000062, 0.000063, 0.000063, 0.000064, 0.000064, 0.000065, 0.000066, 0.000066, +0.000067, 0.000068, 0.000068, 0.000069, 0.000070, 0.000070, 0.000071, 0.000072, 0.000073, 0.000073, 0.000074, 0.000075, 0.000076, 0.000076, 0.000077, 0.000078, 0.000079, 0.000079, 0.000080, 0.000081, 0.000082, 0.000083, 0.000084, 0.000084, 0.000085, 0.000086, 0.000087, 0.000088,0.000089, 0.000090, 0.000090, 0.000091, 0.000092, 0.000093, 0.000094, 0.000095, 0.000096, 0.000097, 0.000098, 0.000099, 0.000100, 0.000101, 0.000102, 0.000103, 0.000104, 0.000105, 0.000106, 0.000107, 0.000108, 0.000109, 0.000111, 0.000112, 0.000113, 0.000114, 0.000115, 0.000116, +0.000117, 0.000119, 0.000120, 0.000121, 0.000122, 0.000123, 0.000125, 0.000126, 0.000127, 0.000128, 0.000130, 0.000131, 0.000132, 0.000134, 0.000135, 0.000136, 0.000138, 0.000139, 0.000141, 0.000142, 0.000143, 0.000145, 0.000146, 0.000148, 0.000149, 0.000151, 0.000152, 0.000154,0.000155, 0.000157, 0.000158, 0.000160, 0.000162, 0.000163, 0.000165, 0.000167, 0.000168, 0.000170, 0.000172, 0.000173, 0.000175, 0.000177, 0.000179, 0.000180, 0.000182, 0.000184, 0.000186, 0.000188, 0.000190, 0.000192, 0.000194, 0.000195, 0.000197, 0.000199, 0.000201, 0.000203, +0.000205, 0.000208, 0.000210, 0.000212, 0.000214, 0.000216, 0.000218, 0.000220, 0.000223, 0.000225, 0.000227, 0.000229, 0.000232, 0.000234, 0.000236, 0.000239, 0.000241, 0.000244, 0.000246, 0.000248, 0.000251, 0.000253, 0.000256, 0.000259, 
0.000261, 0.000264, 0.000266, 0.000269,0.000272, 0.000275, 0.000277, 0.000280, 0.000283, 0.000286, 0.000289, 0.000292, 0.000295, 0.000297, 0.000300, 0.000303, 0.000307, 0.000310, 0.000313, 0.000316, 0.000319, 0.000322, 0.000325, 0.000329, 0.000332, 0.000335, 0.000339, 0.000342, 0.000346, 0.000349, 0.000353, 0.000356, +0.000360, 0.000363, 0.000367, 0.000371, 0.000374, 0.000378, 0.000382, 0.000386, 0.000390, 0.000394, 0.000398, 0.000402, 0.000406, 0.000410, 0.000414, 0.000418, 0.000422, 0.000426, 0.000431, 0.000435, 0.000439, 0.000444, 0.000448, 0.000453, 0.000457, 0.000462, 0.000467, 0.000471,0.000476, 0.000481, 0.000486, 0.000490, 0.000495, 0.000500, 0.000505, 0.000510, 0.000516, 0.000521, 0.000526, 0.000531, 0.000537, 0.000542, 0.000547, 0.000553, 0.000559, 0.000564, 0.000570, 0.000576, 0.000581, 0.000587, 0.000593, 0.000599, 0.000605, 0.000611, 0.000617, 0.000623, +0.000630, 0.000636, 0.000642, 0.000649, 0.000655, 0.000662, 0.000669, 0.000675, 0.000682, 0.000689, 0.000696, 0.000703, 0.000710, 0.000717, 0.000724, 0.000732, 0.000739, 0.000746, 0.000754, 0.000762, 0.000769, 0.000777, 0.000785, 0.000793, 0.000801, 0.000809, 0.000817, 0.000825,0.000833, 0.000842, 0.000850, 0.000859, 0.000867, 0.000876, 0.000885, 0.000894, 0.000903, 0.000912, 0.000921, 0.000930, 0.000939, 0.000949, 0.000958, 0.000968, 0.000978, 0.000988, 0.000998, 0.001008, 0.001018, 0.001028, 0.001038, 0.001049, 0.001059, 0.001070, 0.001081, 0.001092, +0.001103, 0.001114, 0.001125, 0.001136, 0.001148, 0.001159, 0.001171, 0.001182, 0.001194, 0.001206, 0.001218, 0.001231, 0.001243, 0.001256, 0.001268, 0.001281, 0.001294, 0.001307, 0.001320, 0.001333, 0.001347, 0.001360, 0.001374, 0.001388, 0.001402, 0.001416, 0.001430, 0.001444,0.001459, 0.001473, 0.001488, 0.001503, 0.001518, 0.001534, 0.001549, 0.001565, 0.001580, 0.001596, 0.001612, 0.001628, 0.001645, 0.001661, 0.001678, 0.001695, 0.001712, 0.001729, 0.001746, 0.001764, 0.001782, 0.001800, 0.001818, 0.001836, 0.001854, 0.001873, 0.001892, 0.001911, 
+0.001930, 0.001950, 0.001969, 0.001989, 0.002009, 0.002029, 0.002049, 0.002070, 0.002091, 0.002112, 0.002133, 0.002155, 0.002176, 0.002198, 0.002220, 0.002242, 0.002265, 0.002288, 0.002311, 0.002334, 0.002357, 0.002381, 0.002405, 0.002429, 0.002454, 0.002478, 0.002503, 0.002528,0.002554, 0.002579, 0.002605, 0.002632, 0.002658, 0.002685, 0.002712, 0.002739, 0.002767, 0.002794, 0.002822, 0.002851, 0.002879, 0.002908, 0.002938, 0.002967, 0.002997, 0.003027, 0.003057, 0.003088, 0.003119, 0.003151, 0.003182, 0.003214, 0.003247, 0.003279, 0.003312, 0.003345, +0.003379, 0.003413, 0.003447, 0.003482, 0.003517, 0.003552, 0.003588, 0.003624, 0.003660, 0.003697, 0.003734, 0.003772, 0.003810, 0.003848, 0.003887, 0.003926, 0.003965, 0.004005, 0.004045, 0.004086, 0.004127, 0.004169, 0.004211, 0.004253, 0.004296, 0.004339, 0.004382, 0.004426,0.004471, 0.004516, 0.004561, 0.004607, 0.004653, 0.004700, 0.004747, 0.004795, 0.004843, 0.004892, 0.004941, 0.004991, 0.005041, 0.005092, 0.005143, 0.005194, 0.005247, 0.005299, 0.005353, 0.005406, 0.005461, 0.005516, 0.005571, 0.005627, 0.005684, 0.005741, 0.005798, 0.005857, +0.005916, 0.005975, 0.006035, 0.006096, 0.006157, 0.006219, 0.006281, 0.006345, 0.006408, 0.006473, 0.006538, 0.006603, 0.006670, 0.006737, 0.006805, 0.006873, 0.006942, 0.007012, 0.007082, 0.007153, 0.007225, 0.007298, 0.007371, 0.007445, 0.007520, 0.007596, 0.007672, 0.007749,0.007827, 0.007906, 0.007985, 0.008065, 0.008147, 0.008228, 0.008311, 0.008395, 0.008479, 0.008564, 0.008650, 0.008737, 0.008825, 0.008914, 0.009003, 0.009094, 0.009185, 0.009277, 0.009371, 0.009465, 0.009560, 0.009656, 0.009753, 0.009851, 0.009950}, sd::DataType::DOUBLE); + input.linspace(0.01, 0.01); + + ops::helpers::softmax(sd::LaunchContext ::defaultContext(), input, output, 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, logSoftMaxForVector_test1) { + + auto input = 
NDArrayFactory::create('c', {1,5}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {1,5}); + auto expOutput = NDArrayFactory::create('c', {1,5}); + expOutput = 0; + + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, logSoftMaxForVector_test2) { + + auto input= NDArrayFactory::create('c', {5,1}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {5,1}); + auto expOutput = NDArrayFactory::create('c', {5,1}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); + + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, logSoftMaxForVector_test3) { + + auto input= NDArrayFactory::create('c', {5}, {1,2,3,4,5}); + auto output = NDArrayFactory::create('c', {5}); + auto expOutput = NDArrayFactory::create('c', {5}, {-4.4519144, -3.4519144, -2.4519144, -1.4519144, -0.4519144}); + + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, logSoftMaxForVector_test4) { + + NDArray input('c', {1500}, sd::DataType::DOUBLE); + NDArray output('c', {1500}, sd::DataType::DOUBLE); + NDArray expOutput('c', {1500}, {-8.154773, -8.153772, -8.152773, -8.151772, -8.150773, -8.149773, -8.148773, -8.147773, -8.146772, -8.145773, -8.144773, -8.143773, -8.142773, -8.141773, -8.140773, -8.139772, -8.138773, -8.137773, -8.136773, -8.135773, -8.134773, -8.133773, -8.132772, -8.131773, -8.130773, -8.129773, -8.128773, -8.127772, -8.126773, -8.125772, -8.124773, -8.123773, -8.122773, -8.121773, -8.120772, -8.119773, -8.118773, -8.117773, -8.116773, -8.115773, 
-8.114773, -8.113772, -8.112773, -8.111773, -8.110773, -8.109773, -8.108773, -8.107773, -8.106772, -8.105773, -8.104773, -8.103773, -8.102773, -8.101772, -8.100773, -8.099772, -8.098773, -8.097773, -8.096773, -8.095773, -8.094772, -8.093773, -8.092772, -8.091773, -8.090773, -8.089773, -8.088773, -8.087772, -8.086773, -8.085773, -8.084773, -8.083773, -8.082773, -8.081773, -8.080772, -8.079773, -8.078773, -8.077773, -8.076773, -8.075773, -8.074773, -8.073772, -8.072773, -8.071773, -8.070773, -8.069773, -8.068772, -8.067773, -8.066772, -8.065773, -8.064773, -8.063773, -8.062773, -8.061772, -8.060773, -8.059772, -8.058773, -8.057773, -8.056773, -8.055773, -8.054772, +-8.053773, -8.052773, -8.051773, -8.050773, -8.049773, -8.048773, -8.047772, -8.046773, -8.045773, -8.044773, -8.043773, -8.042773, -8.041773, -8.040772, -8.039773, -8.038773, -8.037773, -8.036773, -8.035772, -8.034773, -8.033772, -8.032773, -8.031773, -8.030773, -8.029773, -8.028772, -8.027773, -8.026772, -8.025773, -8.024773, -8.023773, -8.022773, -8.021772, -8.020773, -8.019773, -8.018773, -8.017773, -8.016773, -8.015773, -8.014772, -8.013773, -8.012773, -8.011773, -8.010773, -8.009773, -8.008773, -8.007772, -8.006773, -8.005773, -8.004773, -8.003773, -8.002772, -8.001773, -8.000772, -7.999773, -7.998773, -7.997773, -7.996773, -7.995773, -7.994773, -7.993773, -7.992773, -7.991773, -7.990773, -7.989773, -7.988773, -7.987773, -7.986773, -7.985773, -7.984773, -7.983773, -7.982773, -7.981773, -7.980773, -7.979773, -7.978773, -7.977773, -7.976773, -7.975773, -7.974773, -7.973773, -7.972773, -7.971773, -7.970773, -7.969773, -7.968773, -7.967773, -7.966773, -7.965773, -7.964773, -7.963773, -7.962773, -7.961773, -7.960773, -7.959773, -7.958773, -7.957773, -7.956773, -7.955773, -7.954773, -7.953773, -7.952773, +-7.951773, -7.950773, -7.949773, -7.948773, -7.947773, -7.946773, -7.945773, -7.944773, -7.943773, -7.942773, -7.941773, -7.940773, -7.939773, -7.938773, -7.937773, -7.936773, -7.935773, -7.934773, 
-7.933773, -7.932773, -7.931773, -7.930773, -7.929773, -7.928773, -7.927773, -7.926773, -7.925773, -7.924773, -7.923773, -7.922773, -7.921773, -7.920773, -7.919773, -7.918773, -7.917773, -7.916773, -7.915773, -7.914773, -7.913773, -7.912773, -7.911773, -7.910773, -7.909773, -7.908773, -7.907773, -7.906773, -7.905773, -7.904773, -7.903773, -7.902773, -7.901773, -7.900773, -7.899773, -7.898773, -7.897773, -7.896773, -7.895773, -7.894773, -7.893773, -7.892773, -7.891773, -7.890773, -7.889773, -7.888773, -7.887773, -7.886773, -7.885773, -7.884773, -7.883773, -7.882773, -7.881773, -7.880773, -7.879773, -7.878773, -7.877773, -7.876773, -7.875773, -7.874773, -7.873773, -7.872773, -7.871773, -7.870773, -7.869773, -7.868773, -7.867773, -7.866773, -7.865773, -7.864773, -7.863773, -7.862773, -7.861773, -7.860773, -7.859773, -7.858773, -7.857773, -7.856773, -7.855773, -7.854773, -7.853773, -7.852773, -7.851773, -7.850773, -7.849773, +-7.848773, -7.847773, -7.846773, -7.845773, -7.844773, -7.843773, -7.842773, -7.841773, -7.840773, -7.839773, -7.838773, -7.837773, -7.836773, -7.835773, -7.834773, -7.833773, -7.832773, -7.831773, -7.830773, -7.829773, -7.828773, -7.827773, -7.826773, -7.825773, -7.824773, -7.823773, -7.822773, -7.821773, -7.820773, -7.819773, -7.818773, -7.817773, -7.816773, -7.815773, -7.814773, -7.813773, -7.812773, -7.811773, -7.810773, -7.809773, -7.808773, -7.807773, -7.806773, -7.805773, -7.804773, -7.803773, -7.802773, -7.801773, -7.800773, -7.799773, -7.798773, -7.797773, -7.796773, -7.795773, -7.794773, -7.793773, -7.792773, -7.791773, -7.790773, -7.789773, -7.788773, -7.787773, -7.786773, -7.785773, -7.784773, -7.783773, -7.782773, -7.781773, -7.780773, -7.779773, -7.778773, -7.777773, -7.776773, -7.775773, -7.774773, -7.773773, -7.772773, -7.771773, -7.770773, -7.769773, -7.768773, -7.767773, -7.766773, -7.765773, -7.764773, -7.763773, -7.762773, -7.761773, -7.760773, -7.759773, -7.758773, -7.757773, -7.756773, -7.755773, -7.754773, -7.753773, 
-7.752773, -7.751773, -7.750773, -7.749773, -7.748773, -7.747773, -7.746773, +-7.745773, -7.744773, -7.743773, -7.742773, -7.741773, -7.740773, -7.739773, -7.738773, -7.737773, -7.736773, -7.735773, -7.734773, -7.733773, -7.732773, -7.731773, -7.730773, -7.729773, -7.728773, -7.727773, -7.726773, -7.725773, -7.724773, -7.723773, -7.722773, -7.721773, -7.720773, -7.719773, -7.718773, -7.717773, -7.716773, -7.715773, -7.714773, -7.713773, -7.712773, -7.711773, -7.710773, -7.709773, -7.708773, -7.707773, -7.706773, -7.705773, -7.704773, -7.703773, -7.702773, -7.701773, -7.700773, -7.699773, -7.698773, -7.697773, -7.696773, -7.695773, -7.694773, -7.693773, -7.692773, -7.691773, -7.690773, -7.689773, -7.688773, -7.687773, -7.686773, -7.685773, -7.684773, -7.683773, -7.682773, -7.681773, -7.680773, -7.679773, -7.678773, -7.677773, -7.676773, -7.675773, -7.674773, -7.673773, -7.672773, -7.671773, -7.670773, -7.669773, -7.668773, -7.667773, -7.666773, -7.665773, -7.664773, -7.663773, -7.662773, -7.661773, -7.660773, -7.659773, -7.658773, -7.657773, -7.656773, -7.655773, -7.654773, -7.653773, -7.652773, -7.651773, -7.650773, -7.649773, -7.648773, -7.647773, -7.646773, -7.645773, -7.644773, -7.643773, +-7.642773, -7.641773, -7.640773, -7.639773, -7.638773, -7.637773, -7.636773, -7.635773, -7.634773, -7.633773, -7.632773, -7.631773, -7.630773, -7.629773, -7.628773, -7.627773, -7.626773, -7.625773, -7.624773, -7.623773, -7.622773, -7.621773, -7.620773, -7.619773, -7.618773, -7.617773, -7.616773, -7.615773, -7.614773, -7.613773, -7.612773, -7.611773, -7.610773, -7.609773, -7.608773, -7.607773, -7.606773, -7.605773, -7.604773, -7.603773, -7.602773, -7.601773, -7.600773, -7.599773, -7.598773, -7.597773, -7.596773, -7.595773, -7.594773, -7.593773, -7.592773, -7.591773, -7.590773, -7.589773, -7.588773, -7.587773, -7.586773, -7.585773, -7.584773, -7.583773, -7.582773, -7.581773, -7.580773, -7.579773, -7.578773, -7.577773, -7.576773, -7.575773, -7.574773, -7.573773, -7.572773, 
-7.571773, -7.570773, -7.569773, -7.568773, -7.567773, -7.566773, -7.565773, -7.564773, -7.563773, -7.562773, -7.561773, -7.560773, -7.559773, -7.558773, -7.557773, -7.556773, -7.555773, -7.554773, -7.553773, -7.552773, -7.551773, -7.550773, -7.549773, -7.548773, -7.547773, -7.546773, -7.545773, -7.544773, -7.543773, -7.542773, -7.541773, -7.540773, +-7.539773, -7.538773, -7.537773, -7.536773, -7.535773, -7.534773, -7.533773, -7.532773, -7.531773, -7.530773, -7.529773, -7.528773, -7.527773, -7.526773, -7.525773, -7.524773, -7.523773, -7.522773, -7.521773, -7.520773, -7.519773, -7.518773, -7.517773, -7.516773, -7.515773, -7.514773, -7.513773, -7.512773, -7.511773, -7.510773, -7.509773, -7.508773, -7.507773, -7.506773, -7.505773, -7.504773, -7.503773, -7.502773, -7.501773, -7.500773, -7.499773, -7.498773, -7.497773, -7.496773, -7.495773, -7.494773, -7.493773, -7.492773, -7.491773, -7.490773, -7.489773, -7.488773, -7.487773, -7.486773, -7.485773, -7.484773, -7.483773, -7.482773, -7.481773, -7.480773, -7.479773, -7.478773, -7.477773, -7.476773, -7.475773, -7.474773, -7.473773, -7.472773, -7.471773, -7.470773, -7.469773, -7.468773, -7.467773, -7.466773, -7.465773, -7.464773, -7.463773, -7.462773, -7.461773, -7.460773, -7.459773, -7.458773, -7.457773, -7.456773, -7.455773, -7.454773, -7.453773, -7.452773, -7.451773, -7.450773, -7.449773, -7.448773, -7.447773, -7.446773, -7.445773, -7.444773, -7.443773, -7.442773, -7.441773, -7.440773, -7.439773, -7.438773, -7.437773, +-7.436773, -7.435773, -7.434773, -7.433773, -7.432773, -7.431773, -7.430773, -7.429773, -7.428773, -7.427773, -7.426773, -7.425773, -7.424773, -7.423773, -7.422773, -7.421773, -7.420773, -7.419773, -7.418773, -7.417773, -7.416773, -7.415773, -7.414773, -7.413773, -7.412773, -7.411773, -7.410773, -7.409773, -7.408773, -7.407773, -7.406773, -7.405773, -7.404773, -7.403773, -7.402773, -7.401773, -7.400773, -7.399773, -7.398773, -7.397773, -7.396773, -7.395773, -7.394773, -7.393773, -7.392773, -7.391773, 
-7.390773, -7.389773, -7.388773, -7.387773, -7.386773, -7.385773, -7.384773, -7.383773, -7.382773, -7.381773, -7.380773, -7.379773, -7.378773, -7.377773, -7.376773, -7.375773, -7.374773, -7.373773, -7.372773, -7.371773, -7.370773, -7.369773, -7.368773, -7.367773, -7.366773, -7.365773, -7.364773, -7.363773, -7.362773, -7.361773, -7.360773, -7.359773, -7.358773, -7.357773, -7.356773, -7.355773, -7.354773, -7.353773, -7.352773, -7.351773, -7.350773, -7.349773, -7.348773, -7.347773, -7.346773, -7.345773, -7.344773, -7.343773, -7.342773, -7.341773, -7.340773, -7.339773, -7.338773, -7.337773, -7.336773, -7.335773, -7.334773, +-7.333773, -7.332773, -7.331773, -7.330773, -7.329773, -7.328773, -7.327773, -7.326773, -7.325773, -7.324773, -7.323773, -7.322773, -7.321773, -7.320773, -7.319773, -7.318773, -7.317773, -7.316773, -7.315773, -7.314773, -7.313773, -7.312773, -7.311773, -7.310773, -7.309773, -7.308773, -7.307773, -7.306773, -7.305773, -7.304773, -7.303773, -7.302773, -7.301773, -7.300773, -7.299773, -7.298773, -7.297773, -7.296773, -7.295773, -7.294773, -7.293773, -7.292773, -7.291773, -7.290773, -7.289773, -7.288773, -7.287773, -7.286773, -7.285773, -7.284773, -7.283773, -7.282773, -7.281773, -7.280773, -7.279773, -7.278773, -7.277773, -7.276773, -7.275773, -7.274773, -7.273773, -7.272773, -7.271773, -7.270773, -7.269773, -7.268773, -7.267773, -7.266773, -7.265773, -7.264773, -7.263773, -7.262773, -7.261773, -7.260773, -7.259773, -7.258773, -7.257773, -7.256773, -7.255773, -7.254773, -7.253773, -7.252773, -7.251773, -7.250773, -7.249773, -7.248773, -7.247773, -7.246773, -7.245773, -7.244773, -7.243773, -7.242773, -7.241773, -7.240773, -7.239773, -7.238773, -7.237773, -7.236773, -7.235773, -7.234773, -7.233773, -7.232773, -7.231773, +-7.230773, -7.229773, -7.228773, -7.227773, -7.226773, -7.225773, -7.224773, -7.223773, -7.222773, -7.221773, -7.220773, -7.219773, -7.218773, -7.217773, -7.216773, -7.215773, -7.214773, -7.213773, -7.212773, -7.211773, -7.210773, 
-7.209773, -7.208773, -7.207773, -7.206773, -7.205773, -7.204773, -7.203773, -7.202773, -7.201773, -7.200773, -7.199773, -7.198773, -7.197773, -7.196773, -7.195773, -7.194773, -7.193773, -7.192773, -7.191773, -7.190773, -7.189773, -7.188773, -7.187773, -7.186773, -7.185773, -7.184773, -7.183773, -7.182773, -7.181773, -7.180773, -7.179773, -7.178773, -7.177773, -7.176773, -7.175773, -7.174773, -7.173773, -7.172773, -7.171773, -7.170773, -7.169773, -7.168773, -7.167773, -7.166773, -7.165773, -7.164773, -7.163773, -7.162773, -7.161773, -7.160773, -7.159773, -7.158773, -7.157773, -7.156773, -7.155773, -7.154773, -7.153773, -7.152773, -7.151773, -7.150773, -7.149773, -7.148773, -7.147773, -7.146773, -7.145773, -7.144773, -7.143773, -7.142773, -7.141773, -7.140773, -7.139773, -7.138773, -7.137773, -7.136773, -7.135773, -7.134773, -7.133773, -7.132773, -7.131773, -7.130773, -7.129773, -7.128773, +-7.127773, -7.126773, -7.125773, -7.124773, -7.123773, -7.122773, -7.121773, -7.120773, -7.119773, -7.118773, -7.117773, -7.116773, -7.115773, -7.114773, -7.113773, -7.112773, -7.111773, -7.110773, -7.109773, -7.108773, -7.107773, -7.106773, -7.105773, -7.104773, -7.103773, -7.102773, -7.101773, -7.100773, -7.099773, -7.098773, -7.097773, -7.096773, -7.095773, -7.094773, -7.093773, -7.092773, -7.091773, -7.090773, -7.089773, -7.088773, -7.087773, -7.086773, -7.085773, -7.084773, -7.083773, -7.082773, -7.081773, -7.080773, -7.079773, -7.078773, -7.077773, -7.076773, -7.075773, -7.074773, -7.073773, -7.072773, -7.071773, -7.070773, -7.069773, -7.068773, -7.067773, -7.066773, -7.065773, -7.064773, -7.063773, -7.062773, -7.061773, -7.060773, -7.059773, -7.058773, -7.057773, -7.056773, -7.055773, -7.054773, -7.053773, -7.052773, -7.051773, -7.050773, -7.049773, -7.048773, -7.047773, -7.046773, -7.045773, -7.044773, -7.043773, -7.042773, -7.041773, -7.040773, -7.039773, -7.038773, -7.037773, -7.036773, -7.035773, -7.034773, -7.033773, -7.032773, -7.031773, -7.030773, -7.029773, 
-7.028773, -7.027773, -7.026773, -7.025773, +-7.024773, -7.023773, -7.022773, -7.021773, -7.020773, -7.019773, -7.018773, -7.017773, -7.016773, -7.015773, -7.014773, -7.013773, -7.012773, -7.011773, -7.010773, -7.009773, -7.008773, -7.007773, -7.006773, -7.005773, -7.004773, -7.003773, -7.002773, -7.001773, -7.000773, -6.999773, -6.998773, -6.997773, -6.996773, -6.995773, -6.994773, -6.993773, -6.992773, -6.991773, -6.990773, -6.989773, -6.988773, -6.987773, -6.986773, -6.985773, -6.984773, -6.983773, -6.982773, -6.981773, -6.980773, -6.979773, -6.978773, -6.977773, -6.976773, -6.975773, -6.974773, -6.973773, -6.972773, -6.971773, -6.970773, -6.969773, -6.968773, -6.967773, -6.966773, -6.965773, -6.964773, -6.963773, -6.962773, -6.961773, -6.960773, -6.959773, -6.958773, -6.957773, -6.956773, -6.955773, -6.954773, -6.953773, -6.952773, -6.951773, -6.950773, -6.949773, -6.948773, -6.947773, -6.946773, -6.945773, -6.944773, -6.943773, -6.942773, -6.941773, -6.940773, -6.939773, -6.938773, -6.937773, -6.936773, -6.935773, -6.934773, -6.933773, -6.932773, -6.931773, -6.930773, -6.929773, -6.928773, -6.927773, -6.926773, -6.925773, -6.924773, -6.923773, -6.922773, +-6.921773, -6.920773, -6.919773, -6.918773, -6.917773, -6.916773, -6.915773, -6.914773, -6.913773, -6.912773, -6.911773, -6.910773, -6.909773, -6.908773, -6.907773, -6.906773, -6.905773, -6.904773, -6.903773, -6.902773, -6.901773, -6.900773, -6.899773, -6.898773, -6.897773, -6.896773, -6.895773, -6.894773, -6.893773, -6.892773, -6.891773, -6.890773, -6.889773, -6.888773, -6.887773, -6.886773, -6.885773, -6.884773, -6.883773, -6.882773, -6.881773, -6.880773, -6.879773, -6.878773, -6.877773, -6.876773, -6.875773, -6.874773, -6.873773, -6.872773, -6.871773, -6.870773, -6.869773, -6.868773, -6.867773, -6.866773, -6.865773, -6.864773, -6.863773, -6.862773, -6.861773, -6.860773, -6.859773, -6.858773, -6.857773, -6.856773, -6.855773, -6.854773, -6.853773, -6.852773, -6.851773, -6.850773, -6.849773, -6.848773, 
-6.847773, -6.846773, -6.845773, -6.844773, -6.843773, -6.842773, -6.841773, -6.840773, -6.839773, -6.838773, -6.837773, -6.836773, -6.835773, -6.834773, -6.833773, -6.832773, -6.831773, -6.830773, -6.829773, -6.828773, -6.827773, -6.826773, -6.825773, -6.824773, -6.823773, -6.822773, -6.821773, -6.820773, -6.819773, +-6.818773, -6.817773, -6.816773, -6.815773, -6.814773, -6.813773, -6.812773, -6.811773, -6.810773, -6.809773, -6.808773, -6.807773, -6.806773, -6.805773, -6.804773, -6.803773, -6.802773, -6.801773, -6.800773, -6.799773, -6.798773, -6.797773, -6.796773, -6.795773, -6.794773, -6.793773, -6.792773, -6.791773, -6.790773, -6.789773, -6.788773, -6.787773, -6.786773, -6.785773, -6.784773, -6.783773, -6.782773, -6.781773, -6.780773, -6.779773, -6.778773, -6.777773, -6.776773, -6.775773, -6.774773, -6.773773, -6.772773, -6.771773, -6.770773, -6.769773, -6.768773, -6.767773, -6.766773, -6.765773, -6.764773, -6.763773, -6.762773, -6.761773, -6.760773, -6.759773, -6.758773, -6.757773, -6.756773, -6.755773, -6.754773, -6.753773, -6.752773, -6.751773, -6.750773, -6.749773, -6.748773, -6.747773, -6.746773, -6.745773, -6.744773, -6.743773, -6.742773, -6.741773, -6.740773, -6.739773, -6.738773, -6.737773, -6.736773, -6.735773, -6.734773, -6.733773, -6.732773, -6.731773, -6.730773, -6.729773, -6.728773, -6.727773, -6.726773, -6.725773, -6.724773, -6.723773, -6.722773, -6.721773, -6.720773, -6.719773, -6.718773, -6.717773, -6.716773, -6.715773, +-6.714773, -6.713773, -6.712773, -6.711773, -6.710773, -6.709773, -6.708773, -6.707773, -6.706773, -6.705773, -6.704773, -6.703773, -6.702773, -6.701773, -6.700773, -6.699773, -6.698773, -6.697773, -6.696773, -6.695773, -6.694773, -6.693773, -6.692773, -6.691773, -6.690773, -6.689773, -6.688773, -6.687773, -6.686773, -6.685773, -6.684773, -6.683773, -6.682773, -6.681773, -6.680773, -6.679773, -6.678773, -6.677773, -6.676773, -6.675773, -6.674773, -6.673773, -6.672773, -6.671773, -6.670773, -6.669773, -6.668773, -6.667773, 
-6.666773, -6.665773, -6.664773, -6.663773, -6.662773, -6.661773, -6.660773, -6.659773, -6.658773, -6.657773, -6.656773, -6.655773}, sd::DataType::DOUBLE); + input.linspace(0.01, 0.001); + + ops::helpers::logSoftmax(sd::LaunchContext ::defaultContext(), input, output, 0); + + ASSERT_TRUE(output.equalsTo(&expOutput)); +} + + +////////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulMxV_1) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {M,N}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {5.5, 5.1, 4.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulMxV_2) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {M,N,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(6, {0,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {5.1, 3.3, 1.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulMxV_3) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + 
a.permutei({1,0}); + NDArray temp('f', {N,M,5}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(4, {1,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {6.2, 4.5, 1.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulMxV_4) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('f', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(3, {0,1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {1.5, 1.8, 1.5}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulMxV_5) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('f', {5,M,N}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(2, {0,1}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-0.3, 0.3, 0.9}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulMxV_6) { + + const 
Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(13, {0,2}); + NDArray y('f', {M}, sd::DataType::DOUBLE); + + NDArray exp('f', {M}, {-12.1, -10.9, -9.7}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, mmulMxV_7) { + + const Nd4jLong M = 3; + const Nd4jLong N = 4; + + NDArray a('c', {N,M}, {1.2,1.1,1.0,0.9,0.8,0.7,0.5,0.4,0.3,0.2,0.1,0}, sd::DataType::DOUBLE); + a.permutei({1,0}); + NDArray temp('c', {5,N,M}, {16,2,-6,7,2,-2,4,-7,6,4,4,6,-3,1,3,9,1,4,9,10,-10,-3,-8,7,-7,-7,6,9,7,-6,8,7,-3,-3,4,-2,5,-3,-3,4,6,-5,-1,7,-5,4,-10,-1,8,0,-7,4,-10,-7,-8,-9,2,9,7,9}, sd::DataType::DOUBLE); + NDArray x = temp(10, {0,2}); + NDArray y('c', {M}, sd::DataType::DOUBLE); + + NDArray exp('c', {M}, {3.3, 3.3, 3.3}, sd::DataType::DOUBLE); + + sd::MmulHelper::mmul(&a, &x, &y, 1., 0.); + ASSERT_TRUE(y.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, softmaxDerivative_1) { + + NDArray input('c', {3,3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5.}, sd::DataType::DOUBLE); + NDArray expOutput('c', {3,3}, {0.04508, 0.04514, 0.0008 , 0.0472 , 0.00087, 0.10492, 0.00235, 0.04592, 0.10553}, sd::DataType::DOUBLE); + NDArray output('c', {3,3}, sd::DataType::DOUBLE); + + // input.applyTransform(sd::transform::SoftMaxDerivative, &output); + + sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 0); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, softmaxDerivative_2) { + + NDArray input('c', {3,3,3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14.}, sd::DataType::DOUBLE); + NDArray expOutput('c', {3,3,3}, {4.50755e-02, 4.51394e-02, 6.64586e-03,4.72027e-02, 8.67128e-04, 6.97440e-03,2.35008e-03, 4.59243e-02, 3.32995e-04, + 4.51766e-02, 2.26032e-06, 4.51767e-02,2.91394e-07, 2.37285e-06, 3.94360e-08,4.51769e-02, 1.12535e-07, 4.51767e-02, + 7.58256e-10, 4.51767e-02, 1.22325e-11,7.96007e-10, 1.32293e-11, 1.04994e-01,3.77513e-11, 4.51767e-02, 1.04994e-01}, sd::DataType::DOUBLE); + NDArray output('c', {3,3,3}, sd::DataType::DOUBLE); + + // input.applyTransform(sd::transform::SoftMaxDerivative, &output); + + sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 1); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, softmaxDerivative_3) { + + NDArray input('c', {5}, {-1., 1, -2, 2, 3}, sd::DataType::DOUBLE); + NDArray expOutput('c', {5}, {0.01184, 0.08071, 0.00439, 0.18277, 0.22618}, sd::DataType::DOUBLE); + NDArray output('c', {5}, sd::DataType::DOUBLE); + + // input.applyTransform(sd::transform::SoftMaxDerivative, &output); + + sd::ops::helpers::softmaxDerivative(input.getContext(), input, output, 0); + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, lstmLayerCell_1) { + + const int bS = 2; + const int nIn = 10; + const int nOut = 4; + + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for 
gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x ('c', {bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); + + NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); + + NDArray expH('c', {bS, nOut}, {0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288, 0.999288}, sd::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, {3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778, 3.999778}, sd::DataType::FLOAT32); + + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, lstmLayerCell_2) { + + const int bS = 2; + const int nIn = 10; + const int nOut = 4; + + 
const float dataFormat = 0; // is ignored in cell op + const float cellClip = 3; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x ('c', {bS, nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); + + NDArray h('c', {bS, nOut}, sd::DataType::FLOAT32); + NDArray c('c', {bS, nOut}, sd::DataType::FLOAT32); + + NDArray expH('c', {bS, nOut}, {0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995, 0.995}, sd::DataType::FLOAT32); + NDArray expC('c', {bS, nOut}, {3., 3., 3., 3., 3., 3., 3., 3.}, sd::DataType::FLOAT32); + + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); + + ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); 
+} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests1, lstmLayerCell_3) { + + const int nIn = 10; + const int nOut = 4; + + const float dataFormat = 0; // is ignored in cell op + const float cellClip = 5; // clipping value + const float gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const float cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const float outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not required for tanh + + NDArray x ('c', {nIn}, sd::DataType::FLOAT32); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::FLOAT32); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::FLOAT32); + NDArray b ('c', {4*nOut}, sd::DataType::FLOAT32); + NDArray hI('c', {nOut}, sd::DataType::FLOAT32); + NDArray cI('c', {nOut}, sd::DataType::FLOAT32); + NDArray Wp('c', {3*nOut}, sd::DataType::FLOAT32); + + NDArray h('c', {nOut}, sd::DataType::FLOAT32); + NDArray c('c', {nOut}, sd::DataType::FLOAT32); + + NDArray expH('c', {nOut}, {0.999288, 0.999288, 0.999288, 0.999288}, sd::DataType::FLOAT32); + NDArray expC('c', {nOut}, {3.999778, 3.999778, 3.999778, 3.999778}, sd::DataType::FLOAT32); + + std::vector params = {dataFormat, 0, cellClip, gateAct, gateAlpha, gateBeta, cellAct, cellAlpha, cellBeta, outAct, outAlpha, outBeta}; + + x = 1.; + hI = 2.; + cI = 3.; + Wx = 0.5; + Wr = 0.4; + Wp = 0.3; + b = 0.7; + + sd::ops::helpers::lstmLayerCell(&x, &Wx, &Wr, &b, &hI, &cI, &Wp, params, &h, &c); + + 
ASSERT_TRUE(expH.isSameShape(h)); + ASSERT_TRUE(expH.equalsTo(h)); + ASSERT_TRUE(expC.isSameShape(c)); + ASSERT_TRUE(expC.equalsTo(c)); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HelpersTests2.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HelpersTests2.cpp new file mode 100644 index 000000000..110617937 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/HelpersTests2.cpp @@ -0,0 +1,429 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +#include "testlayers.h" +#include +#include +#include +#include +#include + +using namespace sd; + +class HelpersTests2 : public testing::Test { +public: + + HelpersTests2() { + + std::cout< hess1(x1); + ASSERT_TRUE(hess1._H.isSameShape(&x1)); + ASSERT_TRUE(hess1._H.equalsTo(&x1)); + ASSERT_TRUE(hess1._Q.isSameShape(&expQ)); + ASSERT_TRUE(hess1._Q.equalsTo(&expQ)); + + ops::helpers::Hessenberg hess2(x2); + ASSERT_TRUE(hess2._H.isSameShape(&x2)); + ASSERT_TRUE(hess2._H.equalsTo(&x2)); + ASSERT_TRUE(hess2._Q.isSameShape(&expQ)); + ASSERT_TRUE(hess2._Q.equalsTo(&expQ)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, Hessenberg_2) { + + NDArray x('c', {2,2}, {1.5,-2,17,5}, sd::DataType::DOUBLE); + NDArray expQ('c', {2,2}, {1,0,0,1}, sd::DataType::DOUBLE); + + ops::helpers::Hessenberg hess(x); + + // hess._H.printBuffer(); + + ASSERT_TRUE(hess._H.isSameShape(&x)); + ASSERT_TRUE(hess._H.equalsTo(&x)); + + ASSERT_TRUE(hess._Q.isSameShape(&expQ)); + ASSERT_TRUE(hess._Q.equalsTo(&expQ)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, Hessenberg_3) { + + NDArray x('c', {3,3}, {33,24,-48,57,12.5,-3,1.1,10,-5.2}, sd::DataType::DOUBLE); + NDArray expH('c', {3,3}, {33, -23.06939, -48.45414, -57.01061, 12.62845, 3.344058, 0, -9.655942, -5.328448}, sd::DataType::DOUBLE); + NDArray expQ('c', {3,3}, {1,0,0,0, -0.99981, -0.019295, 0, -0.019295,0.99981}, sd::DataType::DOUBLE); + + ops::helpers::Hessenberg hess(x); + + ASSERT_TRUE(hess._H.isSameShape(&expH)); + ASSERT_TRUE(hess._H.equalsTo(&expH)); + + ASSERT_TRUE(hess._Q.isSameShape(&expQ)); + ASSERT_TRUE(hess._Q.equalsTo(&expQ)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, Hessenberg_4) { + + NDArray x('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 
,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); + NDArray expH('c', {4,4}, {0.33, 0.4961181, 3.51599, 9.017665, -7.792702, 4.190221, 6.500328, 5.438888, 0, 3.646734, 0.4641911, -7.635502, 0,0, 5.873535, 5.105588}, sd::DataType::DOUBLE); + NDArray expQ('c', {4,4}, {1,0,0,0, 0,-0.171956, 0.336675, -0.925787, 0,-0.973988,0.0826795, 0.210976, 0, 0.147574, 0.937984,0.3137}, sd::DataType::DOUBLE); + + ops::helpers::Hessenberg hess(x); + + ASSERT_TRUE(hess._H.isSameShape(&expH)); + ASSERT_TRUE(hess._H.equalsTo(&expH)); + + ASSERT_TRUE(hess._Q.isSameShape(&expQ)); + ASSERT_TRUE(hess._Q.equalsTo(&expQ)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, Hessenberg_5) { + + NDArray x('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 , + -6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 , + 0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 , + 6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE); + NDArray expH('c', {10,10}, {6.9, 6.125208, -8.070945, 7.219828, -9.363308, 2.181236, 5.995414, 3.892612, 4.982657, -2.088574,-12.6412, 1.212547, -6.449684, 5.162879, 0.4341714, -5.278079, -2.624011, -2.03615, 11.39619, -3.034842, + 0, -12.71931, 10.1146, 6.494434, -1.062934, 5.668906, -4.672953, -9.319893, -2.023392, 6.090341,0,0, 7.800521, -1.46286, 1.484626, -10.58252, -3.492978, 2.42187, 5.470045, 1.877265, + 0,0,0, 14.78259,-0.3147726, -5.74874, -0.377823, 3.310056, 2.242614, -5.111574,0,0,0,0, -9.709131, 3.885072, 6.762626, 4.509144, 2.390195, -4.991013, + 0,0,0,0,0, 8.126269, -12.32529, 9.030151, 1.390931, 
0.8634045,0,0,0,0,0,0, -12.99477, 9.574299,-0.3098022, 4.910835,0,0,0,0,0,0,0, 14.75256, 18.95723, -5.054717,0,0,0,0,0,0,0,0, -4.577715, -5.440827,}, sd::DataType::DOUBLE); + NDArray expQ('c', {10,10}, {1,0,0,0,0,0,0,0,0,0,0,-0.0079106,-0.38175,-0.39287,-0.26002,-0.44102,-0.071516,0.12118,0.64392,0.057562, + 0,0.28478,0.0058784,0.3837,-0.47888,0.39477,0.0036847,-0.24678,0.3229,0.47042,0,-0.031643,-0.61277,0.087648,0.12014,0.47648,-0.5288,0.060599,0.021434,-0.30102, + 0,0.23732,-0.17801,-0.31809,-0.31267,0.27595,0.30134,0.64555,-0.33392,0.13363,0,-0.023732,-0.40236,0.43089,-0.38692,-0.5178,-0.03957,-0.081667,-0.47515,-0.0077949, + 0,0.20568,-0.0169,0.36962,0.49669,-0.22475,-0.22199,0.50075,0.10454,0.46112,0,0.41926,0.30243,-0.3714,-0.16795,-0.12969,-0.67572,-0.1205,-0.26047,0.10407, + 0,-0.41135,-0.28357,-0.33858,0.18836,0.083822,-0.0068213,-0.30161,-0.24956,0.66327,0,0.68823,-0.33616,-0.12129,0.36163,-0.063256,0.34198,-0.37564,-0.048196,-0.058948}, sd::DataType::DOUBLE); + + ops::helpers::Hessenberg hess(x); + + ASSERT_TRUE(hess._H.isSameShape(&expH)); + ASSERT_TRUE(hess._H.equalsTo(&expH)); + + ASSERT_TRUE(hess._Q.isSameShape(&expQ)); + ASSERT_TRUE(hess._Q.equalsTo(&expQ)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, Schur_1) { + + NDArray x('c', {3,3}, sd::DataType::DOUBLE); + + NDArray expT('c', {3,3}, {-2.5, -2, 1, 0, 1.5, -2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray expU('c', {3,3}, {0.3, 0.2,-0.1, 0,-0.1, 0.2, -0.3,-0.4, 0.5}, sd::DataType::DOUBLE); + + ops::helpers::Schur schur(x); + schur._T.linspace(-3, 1); + schur._U.linspace(-0.3, 0.1); + + schur.splitTwoRows(1, 0.5); + + ASSERT_TRUE(schur._T.isSameShape(&expT)); + ASSERT_TRUE(schur._T.equalsTo(&expT)); + + ASSERT_TRUE(schur._U.isSameShape(&expU)); + ASSERT_TRUE(schur._U.equalsTo(&expU)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, Schur_2) { + + NDArray x('c', {3,3}, sd::DataType::DOUBLE); + + 
NDArray shift('c', {3}, sd::DataType::DOUBLE); + NDArray exp1('c', {3}, {1,-3,0}, sd::DataType::DOUBLE); + NDArray exp2('c', {3}, {3, 3,-7}, sd::DataType::DOUBLE); + NDArray exp3('c', {3}, {0.964,0.964,0.964}, sd::DataType::DOUBLE); + NDArray exp1T('c', {3,3}, {-3,-2,-1,0,1,2,3,4,5}, sd::DataType::DOUBLE); + NDArray exp2T('c', {3,3}, {-8,-2,-1,0,-4,2,3,4,0}, sd::DataType::DOUBLE); + NDArray exp3T('c', {3,3}, {-9.464102,-2,-1,0,-5.464102,2,3,4,-1.464102,}, sd::DataType::DOUBLE); + + ops::helpers::Schur schur(x); + // schur._U.linspace(-0.3, 0.1); // doesn't matter + + schur._T.linspace(-3, 1); + double expShift =0; + schur.calcShift(1, 5, expShift, shift); + ASSERT_TRUE(schur._T.equalsTo(&exp1T)); + ASSERT_TRUE(shift.isSameShape(&exp1)); + ASSERT_TRUE(shift.equalsTo(&exp1)); + ASSERT_TRUE(expShift == 0); + + schur._T.linspace(-3, 1); + expShift = 0; + schur.calcShift(2, 10, expShift, shift); + ASSERT_TRUE(schur._T.equalsTo(&exp2T)); + ASSERT_TRUE(shift.isSameShape(&exp2)); + ASSERT_TRUE(shift.equalsTo(&exp2)); + ASSERT_TRUE(expShift == 5); + + schur._T.linspace(-3, 1); + expShift = 0; + schur.calcShift(2, 30, expShift, shift); + ASSERT_TRUE(schur._T.equalsTo(&exp3T)); + ASSERT_TRUE(shift.isSameShape(&exp3)); + ASSERT_TRUE(shift.equalsTo(&exp3)); + ASSERT_TRUE((6.4641-0.00001) < expShift && expShift < (6.4641+0.00001)); +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, Schur_3) { + + NDArray x('c', {2,2}, {1.5,-2,17,5}, sd::DataType::DOUBLE); + NDArray expU('c', {2,2}, {1,0,0,1}, sd::DataType::DOUBLE); + + ops::helpers::Schur schur(x); + + ASSERT_TRUE(schur._T.isSameShape(&x)); + ASSERT_TRUE(schur._T.equalsTo(&x)); + + ASSERT_TRUE(schur._U.isSameShape(&expU)); + ASSERT_TRUE(schur._U.equalsTo(&expU)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, Schur_4) { + + NDArray x('c', {3,3}, {33,24,-48,57,12.5,-3,1.1,10,-5.2}, sd::DataType::DOUBLE); + NDArray expT('c', {3,3}, 
{53.73337,-20.21406,-50.44809,0,-27.51557, 26.74307,0,0,14.0822}, sd::DataType::DOUBLE); + NDArray expU('c', {3,3}, {-0.5848506, 0.7185352, 0.3763734,-0.7978391,-0.5932709,-0.1071558,-0.1462962, 0.3629555,-0.9202504}, sd::DataType::DOUBLE); + + ops::helpers::Schur schur(x); + + ASSERT_TRUE(schur._T.isSameShape(&expT)); + ASSERT_TRUE(schur._T.equalsTo(&expT)); + + ASSERT_TRUE(schur._U.isSameShape(&expU)); + ASSERT_TRUE(schur._U.equalsTo(&expU)); +} + +/* +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, Schur_5) { + + NDArray x('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); + NDArray expT('c', {4,4}, {6.940177,7.201107,2.523849,-8.534745,-3.109643,5.289615,-2.940507,9.330303, 0,0,-0.1740346, 7.19851,0,0, -2.870214, -1.965758}, sd::DataType::DOUBLE); + NDArray expU('c', {4,4}, {-0.2602141, 0.8077556,-0.3352316,-0.4091935,0.3285353,-0.4395489,-0.4714875,-0.6903338,0.7536921, 0.3005626,-0.3910435, 0.4343908,-0.5062621, -0.252962,-0.7158242, 0.4090287}, sd::DataType::DOUBLE); + + ops::helpers::Schur schur(x); + + ASSERT_TRUE(schur._T.isSameShape(&expT)); + ASSERT_TRUE(schur._T.equalsTo(&expT)); + + ASSERT_TRUE(schur._U.isSameShape(&expU)); + ASSERT_TRUE(schur._U.equalsTo(&expU)); +} +*/ +/* +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, Schur_6) { + + NDArray x('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 , + -6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 , + 0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 , + 6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 
,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE); + NDArray expT('c', {10,10}, {-13.78982, 6.072464, 0.3021194, -8.455495,-0.3047058, 4.033153, 2.610364, 2.80607, -2.735616, 0.3040549,-2.188506, -12.38324, -1.167179, -4.539672, -19.08546, 1.752401,-0.1354974,-0.2747422,-0.3270464, -5.070936, + 0,0,0.5067366, 7.930223,-0.6465996, 8.659522, 1.283713, 4.551415, 12.7736, 3.4812,0,0,-9.858142, -2.905068, -6.474159, -6.247967, 0.4720073, -10.49523, 3.617189, -4.941627, + 0,0,0,0,9.461626, -4.896166, 9.339704, 4.640336, 16.8626, 2.056027,0,0,0,0,6.479812, 8.462862, 7.386285, -4.123457, -5.817095, -2.633641,0,0,0,0,0,0,13.46667, -4.907281, 4.602204, 5.198035, + 0,0,0,0,0,0, 7.176822, 16.93311, 2.195036, 1.346086,0,0,0,0,0,0,0,0, 16.86979, -3.052473,0,0,0,0,0,0,0,0,0, -5.52268}, sd::DataType::DOUBLE); + + // NDArray expT('c', {10,10}, {-13.78982, 6.072464, 0.1926198, -8.458698,-0.3047363, 4.033151, 2.610336, 2.806096, -2.735616, 0.3040549,-2.188506, -12.38324, -1.225857, -4.52418, -19.08548, 1.752257,-0.1354946,-0.2747435,-0.3270464, -5.070936, + // 0,0, 0.4812058, 7.886377,-0.7304318, 8.577898, 1.289673, 4.415163, 12.81936, 3.416929,0,0, -9.901988, -2.879537, -6.465196, -6.359608, 0.455452, -10.55328, 3.451505, -4.986284, + // 0,0,0,0, 9.461614, -4.896159, 9.339602, 4.64046, 16.86265, 2.056047,0,0,0,0, 6.47982, 8.462874, 7.386396, -4.123349, -5.816967, -2.633626, + // 0,0,0,0,0,0, 13.46665, -4.907315, 4.602182, 5.198022,0,0,0,0,0,0, 7.176788, 16.93313, 2.195081, 1.346137,0,0,0,0,0,0,0,0, 16.86979, -3.052473,0,0,0,0,0,0,0,0,0, -5.52268}, sd::DataType::DOUBLE); + + NDArray expU('c', {10,10}, {0.1964177, 0.2165192, -0.2138164, 0.4083154, -0.1872303, -0.5087223, 0.5529025, -0.2996174,-0.08772947, 0.07126534,-0.1906247, -0.223588, 0.3574755, 0.4245914, -0.3885589,-0.07328949, -0.4176507, -0.1885168, -0.4476957, 0.1971104, + -0.2219015, 0.3084187, 0.1069209, -0.4905009, -0.3517786, 0.1446875, 0.121738, -0.3772941, 0.1232591, 0.5353205,-0.4766346, 0.6158252, -0.1529085, 
0.04780914, 0.1274182, -0.1219211, -0.3123289, -0.2219282,-0.07613826, -0.429201, + 0.2577533, -0.3356205, -0.225358, -0.1540796, 0.3155174, -0.1904664, -0.3567101, -0.6831458, 0.1244646, 0.03383783, -0.45597, -0.3350697, 0.06824276, -0.2861978,-0.06724917, -0.7046481, 0.01664764, 0.2270567, 0.2003283,-0.01544937, + 0.122865, 0.1516775, -0.4446453, -0.2338583, 0.1633447, -0.193498, -0.198088, 0.3170272, -0.5869794, 0.4013553, 0.347383, 0.3666581, 0.6890763,-0.05797414, 0.3630058, -0.319958, -0.1071812, 0.06162044, 0.03171228, 0.1275262, + -0.2986812, 0.05382598, -0.1484276, 0.4936468, 0.362756, 0.05858297, -0.1055183, 0.1090384, 0.4217073, 0.5534347, 0.3864388, 0.2085926, -0.204135, 0.05230855, -0.5290207, -0.1548485, -0.4670302, 0.2205726, 0.4380318,-0.01626632}, sd::DataType::DOUBLE); + + ops::helpers::Schur schur(x); + + ASSERT_TRUE(schur._T.isSameShape(&expT)); + ASSERT_TRUE(schur._T.equalsTo(&expT, 1e-3)); + + ASSERT_TRUE(schur._U.isSameShape(&expU)); + ASSERT_TRUE(schur._U.equalsTo(&expU)); +} +*/ + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, EigenValsAndVecs_1) { + + NDArray x('c', {2,2}, {1.5,-2,17,5}, sd::DataType::DOUBLE); + NDArray expVals('c', {2,2}, {3.25,5.562149, 3.25,-5.562149}, sd::DataType::DOUBLE); + NDArray expVecs('c', {2,2,2}, {-0.3094862,-0.0973726, -0.3094862,0.0973726,0,0.9459053, 0,-0.9459053}, sd::DataType::DOUBLE); + + ops::helpers::EigenValsAndVecs eig(x); + + ASSERT_TRUE(eig._Vals.isSameShape(&expVals)); + ASSERT_TRUE(eig._Vals.equalsTo(&expVals)); + + ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs)); + ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, EigenValsAndVecs_2) { + + NDArray x('c', {3,3}, {33,24,-48,57,12.5,-3,1.1,10,-5.2}, sd::DataType::DOUBLE); + NDArray expVals('c', {3,2}, {53.73337,0, -27.51557,0, 14.0822,0}, sd::DataType::DOUBLE); + NDArray expVecs('c', {3,3,2}, 
{-0.5848506,0,0.5560778,0,-0.04889745,0,-0.7978391,0,-0.7683444,0,-0.8855156,0,-0.1462962,0,0.3168979,0,-0.4620293,0}, sd::DataType::DOUBLE); + + ops::helpers::EigenValsAndVecs eig(x); + + ASSERT_TRUE(eig._Vals.isSameShape(&expVals)); + ASSERT_TRUE(eig._Vals.equalsTo(&expVals)); + + ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs)); + ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, EigenValsAndVecs_3) { + + NDArray x('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); + NDArray expVals('c', {4,2}, {6.114896,4.659591,6.114896,-4.659591, -1.069896,4.45631,-1.069896,-4.45631}, sd::DataType::DOUBLE); + NDArray expVecs('c', {4,4,2}, {-0.2141303,0.4815241,-0.2141303,-0.4815241, 0.1035092,-0.4270603, 0.1035092,0.4270603, 0.2703519,-0.2892722, 0.2703519,0.2892722, -0.5256817,0.044061, -0.5256817,-0.044061, + 0.6202137,0.05521234,0.6202137,-0.05521234, -0.5756007,0.3932209,-0.5756007,-0.3932209,-0.4166034,-0.0651337, -0.4166034,0.0651337, -0.1723716,0.1138941,-0.1723716,-0.1138941}, sd::DataType::DOUBLE); + + ops::helpers::EigenValsAndVecs eig(x); + + ASSERT_TRUE(eig._Vals.isSameShape(&expVals)); + ASSERT_TRUE(eig._Vals.equalsTo(&expVals)); + + ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs)); + ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs)); +} + +/* +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, EigenValsAndVecs_4) { + + NDArray x('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 , + -6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 , + 0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 , + 6.0 
,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE); + NDArray expVals('c', {10,2}, { -13.08653,3.577011,-13.08653,-3.577011, -1.199166,8.675665,-1.199166,-8.675665,8.962244, + 5.610424, 8.962244,-5.610424, 15.19989,5.675794, 15.19989,-5.675794,16.86979,0,-5.52268,0}, sd::DataType::DOUBLE); + NDArray expVecs('c', {10,10,2}, {0.1652385,0.1439317, 0.1652385,-0.1439317, -0.198272,0.207306, -0.198272,-0.207306, 0.1861466,-0.4599919, 0.1861466,0.4599919, 0.09384053,-0.4889922, 0.09384053,0.4889922, -0.6153314,0, -0.2180209,0, + -0.1603652,-0.1466119, -0.1603652,0.1466119, 0.2817409,0.3301842, 0.2817409,-0.3301842, 0.09747303,-0.2218182, 0.09747303,0.2218182, 0.2318273,-0.3355113, 0.2318273,0.3355113, -0.4828878,0, -0.1451126,0, + -0.1866771,0.1220412, -0.1866771,-0.1220412, 0.08937842,-0.3025104, 0.08937842,0.3025104, 0.2783766,0.2258364, 0.2783766,-0.2258364, -0.1413997,-0.09596012, -0.1413997,0.09596012, -0.2286925,0, 0.3290011,0, + -0.4009741,0.238131, -0.4009741,-0.238131, -0.02772353,0.1338458, -0.02772353,-0.1338458, 0.09030543,-0.2222453, 0.09030543,0.2222453, 0.2565825,-0.2275446, 0.2565825,0.2275446, -0.2855937,0, -0.3950544,0, + 0.2168379,-0.1301121, 0.2168379,0.1301121, -0.165433,-0.1220125, -0.165433,0.1220125, -0.2685605,0.008133055,-0.2685605,-0.008133055, 0.1929395,-0.1194659, 0.1929395,0.1194659, 0.2206467,0, 0.3289105,0, + -0.3835898,-0.2478813, -0.3835898,0.2478813, 0.1923005,-0.01036433, 0.1923005,0.01036433, -0.1711637,-0.3548358, -0.1711637,0.3548358, 0.2888441,0.09625169, 0.2888441,-0.09625169, 0.2595426,0, -0.1288072,0, + 0.1033616,0.09839151, 0.1033616,-0.09839151, -0.3080167,-0.1624564, -0.3080167,0.1624564,-0.03972293,-0.03967309, -0.03972293,0.03967309, 0.1965443,0.3025898, 0.1965443,-0.3025898, 0.04587166,0, 0.499261,0, + 0.2922398,0.2461792, 0.2922398,-0.2461792, 0.2769633,-0.2745029, 0.2769633,0.2745029, 0.1034687,-0.002947149, 
0.1034687,0.002947149, -0.02611308,0.1658046, -0.02611308,-0.1658046, 0.2351063,0, -0.3787892,0, + -0.2512689,-0.02169855, -0.2512689,0.02169855, -0.01481625,0.4376404, -0.01481625,-0.4376404, -0.2298635,-0.2360671, -0.2298635,0.2360671, 0.11004,-0.1467444, 0.11004,0.1467444, 0.1501568,0, 0.340117,0, + 0.325096,0.1712822, 0.325096,-0.1712822, -0.2412035,-0.09236849, -0.2412035,0.09236849, 0.3894343,-0.08673087, 0.3894343,0.08673087, 0.3125305,0.07128152, 0.3125305,-0.07128152, -0.2415555,0, 0.1841298,0,}, sd::DataType::DOUBLE); + + ops::helpers::EigenValsAndVecs eig(x); + + ASSERT_TRUE(eig._Vals.isSameShape(&expVals)); + ASSERT_TRUE(eig._Vals.equalsTo(&expVals)); + + ASSERT_TRUE(eig._Vecs.isSameShape(&expVecs)); + ASSERT_TRUE(eig._Vecs.equalsTo(&expVecs)); +} +*/ + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, fullPivLU_1) { + + NDArray a('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); + NDArray b('c', {4,1}, {-5.,10,9,1}, sd::DataType::DOUBLE); + + NDArray x = b.ulike(); + + NDArray expX('c', {4,1}, {0.8527251, -0.2545784, -1.076495, -0.8526268}, sd::DataType::DOUBLE); + + ops::helpers::FullPivLU::solve(a,b,x); + + ASSERT_TRUE(x.equalsTo(&expX)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, fullPivLU_2) { + + NDArray a('c', {4,4}, {0.33 ,-7.25 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,7.59 ,3.44 ,2.24 ,-6.82 ,-1.15 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); + NDArray b('c', {4,2}, {-5.,10,9,1,1.5,-2,17,5}, sd::DataType::DOUBLE); + + NDArray x = b.ulike(); + + NDArray expX('c', {4,2}, {1.462913, 1.835338, 0.4083664, -2.163816, -3.344481, -3.739225, 0.5156383,0.01624954}, sd::DataType::DOUBLE); + + ops::helpers::FullPivLU::solve(a,b,x); + + ASSERT_TRUE(x.equalsTo(&expX)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, fullPivLU_3) { 
+ + NDArray a1('c', {4,3}, {0.33 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,2.24 ,-6.82 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); + NDArray a2('c', {3,4}, {0.33 ,1.71 ,6.20 ,1.34 ,5.38 ,-2.76 ,-8.51 ,2.24 ,-6.82 ,4.80 ,-4.67 ,2.14}, sd::DataType::DOUBLE); + NDArray b1('c', {4,2}, {-5.,10,9,1,1.5,-2,17,5}, sd::DataType::DOUBLE); + NDArray b2('c', {3,2}, {-5.,10,9,1,1.5,-2}, sd::DataType::DOUBLE); + + NDArray expX1('c', {3,2}, {0.9344955,-0.5841325, 0.8768102, 1.029137, -1.098021, 1.360152}, sd::DataType::DOUBLE); + NDArray expX2('c', {4,2}, {0.3536033,0.5270184,0,0,-0.8292221,0.967515,0.01827441,2.856337}, sd::DataType::DOUBLE); + + NDArray x1 = expX1.ulike(); + ops::helpers::FullPivLU::solve(a1,b1,x1); + ASSERT_TRUE(x1.equalsTo(&expX1)); + + NDArray x2 = expX2.ulike(); + ops::helpers::FullPivLU::solve(a2,b2,x2); + ASSERT_TRUE(x2.equalsTo(&expX2)); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(HelpersTests2, fullPivLU_4) { + + NDArray a('c', {10,10}, {6.9 ,4.8 ,9.5 ,3.1 ,6.5 ,5.8 ,-0.9 ,-7.3 ,-8.1 ,3.0 ,0.1 ,9.9 ,-3.2 ,6.4 ,6.2 ,-7.0 ,5.5 ,-2.2 ,-4.0 ,3.7 ,-3.6 ,9.0 ,-1.4 ,-2.4 ,1.7 , + -6.1 ,-4.2 ,-2.5 ,-5.6 ,-0.4 ,0.4 ,9.1 ,-2.1 ,-5.4 ,7.3 ,3.6 ,-1.7 ,-5.7 ,-8.0 ,8.8 ,-3.0 ,-0.5 ,1.1 ,10.0 ,8.0 ,0.8 ,1.0 ,7.5 ,3.5 ,-1.8 , + 0.3 ,-0.6 ,-6.3 ,-4.5 ,-1.1 ,1.8 ,0.6 ,9.6 ,9.2 ,9.7 ,-2.6 ,4.3 ,-3.4 ,0.0 ,-6.7 ,5.0 ,10.5 ,1.5 ,-7.8 ,-4.1 ,-5.3 ,-5.0 ,2.0 ,-4.4 ,-8.4 , + 6.0 ,-9.4 ,-4.8 ,8.2 ,7.8 ,5.2 ,-9.5 ,-3.9 ,0.2 ,6.8 ,5.7 ,-8.5 ,-1.9 ,-0.3 ,7.4 ,-8.7 ,7.2 ,1.3 ,6.3 ,-3.7 ,3.9 ,3.3 ,-6.0 ,-9.1 ,5.9}, sd::DataType::DOUBLE); + NDArray b('c', {10,2}, {-5.,10,9,1,1.5,-2,17,5,3.6,0.12, -3.1,2.27,-0.5,27.3,8.9,5,-7,8,-9,10}, sd::DataType::DOUBLE); + + NDArray x = b.ulike(); + + NDArray expX('c', {10,2}, {-0.697127, 2.58257, 2.109721,3.160622,-2.217796, -3.275736,-0.5752479, 2.475356,1.996841, -1.928947, + 2.213154,3.541014, 0.7104885, -1.981451,-3.297972,-0.4720612, 3.672657, 0.9161028, -2.322383, -1.784493}, sd::DataType::DOUBLE); + + 
ops::helpers::FullPivLU::solve(a,b,x); + + ASSERT_TRUE(x.equalsTo(&expX)); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/IndexingTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/IndexingTests.cpp new file mode 100644 index 000000000..922827b7f --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/IndexingTests.cpp @@ -0,0 +1,472 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 31.10.2017. 
+// + +#include "testlayers.h" +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class IndexingTests : public testing::Test { +public: + +}; + +TEST_F(IndexingTests, StridedSlice_1) { + auto x = NDArrayFactory::create('c', {3, 3, 3}); + auto exp = NDArrayFactory::create('c', {1, 1, 3}); + exp.p(0, 25.f); + exp.p(1, 26.f); + exp.p(2, 27.f); + + x.linspace(1); + auto begin = NDArrayFactory::create({2,2, 0}); + auto end = NDArrayFactory::create({3,3,3}); + auto strides = NDArrayFactory::create({1,1,1}); + + + sd::ops::strided_slice op; + + auto result = op.evaluate({&x, &begin, &end, &strides}, {}, {0,0,0,0,0}); //, 2,2,0, 3,3,3, 1,1,1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, StridedSlice_2) { + auto x = NDArrayFactory::create('c', {5, 5, 5}); + auto exp = NDArrayFactory::create('c', {2, 3, 3}, {86.f, 87.f, 88.f, 91.f, 92.f, 93.f, 96.f, 97.f, 98.f, 111.f, 112.f, 113.f, 116.f, 117.f, 118.f, 121.f, 122.f, 123.f}); + + x.linspace(1); + + sd::ops::strided_slice op; + + auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, StridedSlice_3) { + auto x = NDArrayFactory::create('c', {5, 5, 5}); + auto exp = NDArrayFactory::create('c', {2, 3, 2}, {86.f, 88.f, 91.f, 93.f, 96.f, 98.f, 111.f, 113.f, 116.f, 118.f, 121.f, 123.f}); + + x.linspace(1); + + sd::ops::strided_slice op; + + auto result = op.evaluate({&x}, {}, {0,0,0,0,0, 3,2,0, 5,5,3, 1,1,2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, SimpleSlice_1) { + + auto input = NDArrayFactory::create('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 
4, 4, 4, 5, 5, 5, 6, 6, 6}); + + auto exp = NDArrayFactory::create('c', {1, 1, 3}); + exp.p(0, 3.0f); + exp.p(1, 3.0f); + exp.p(2, 3.0f); + + sd::ops::slice op; + + auto result = op.evaluate({&input}, {}, {1,0,0, 1,1,3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, SimpleSlice_2) { + auto input = NDArrayFactory::create('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + + auto exp = NDArrayFactory::create('c', {1, 2, 3}); + exp.p(0, 3.0f); + exp.p(1, 3.0f); + exp.p(2, 3.0f); + exp.p(3, 4.0f); + exp.p(4, 4.0f); + exp.p(5, 4.0f); + + sd::ops::slice op; + + auto result = op.evaluate({&input}, {}, {1,0,0, 1,2,3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(IndexingTests, SimpleSlice_3) { + auto input = NDArrayFactory::create('c', {3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + + auto exp = NDArrayFactory::create('c', {2, 1, 3}); + exp.p(0, 3.0f); + exp.p(1, 3.0f); + exp.p(2, 3.0f); + exp.p(3, 5.0f); + exp.p(4, 5.0f); + exp.p(5, 5.0f); + + sd::ops::slice op; + + auto result = op.evaluate({&input}, {}, {1,0,0, 2,1,3}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(IndexingTests, SimpleSlice_4) { + auto input = NDArrayFactory::create('c', {3, 2, 3}, {1.0, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6}); + auto start = NDArrayFactory::create('c', {3}, {1.0, 0.0, 0.0}); + auto stop = NDArrayFactory::create('c', {3}, {2.0, 1.0, 3.0}); + auto exp = NDArrayFactory::create('c', {2, 1, 3}, {3.0, 3.0, 3.0, 5.0, 5.0, 5.0}); + + sd::ops::slice op; + + auto result = op.evaluate({&input, &start, &stop}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); 
+ + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, MaskedSlice_0) { + auto matrix = NDArrayFactory::create('c', {3, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); + } + + auto exp = NDArrayFactory::create('c', {1, 5}); + exp.assign(2.0f); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 2, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + // z->printShapeInfo("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, MaskedSlice_00) { + auto matrix = NDArrayFactory::create('c', {3, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); + } + + auto exp = NDArrayFactory::create('c', {1, 2}, {2, 2}); + + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0,0,0,0,0, 1, 1, 2, 3, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, MaskedSlice_1) { + auto matrix = NDArrayFactory::create('c', {3, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); + } + + auto exp = NDArrayFactory::create('c', {5}); + exp.assign(2.0f); + + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 2, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + // z->printShapeInfo("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(IndexingTests, MaskedSlice_2) { + + auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 
7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create('c', {3, 3}, {4.000000f, 4.200000f, 4.300000f, 5.000000f, 5.200000f, 5.300000f, 6.000000f, 6.200000f, 6.300000f}); + + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0,0,0,0,1, 1, 0, 0, 3, 3, 3, 1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, MaskedSlice_3) { + + auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create('c', {2, 3}, { 4.f, 4.2f, 4.3f, 7.f, 7.2f, 7.3f}); + + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3, 1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, MaskedSlice_4) { + + auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create('c', {3}, { 4.f, 4.2f, 4.3f}); + + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix}, {}, {0,0,0,0, 3, 1, 0, 0, 3, 3, 3, 1, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(IndexingTests, Live_Slice_1) { + auto matrix = NDArrayFactory::create('c', {3, 3, 3}, {1.f, 
1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp = NDArrayFactory::create('c', {3}, { 4.f, 4.2f, 4.3f}); + + auto begin = NDArrayFactory::create('c', {3}, {1.0f, 0.0f, 0.0f}); + auto end = NDArrayFactory::create('c', {3}, {3.0f, 3.0f, 3.0f}); + auto stride = NDArrayFactory::create('c', {3}, {1.0f, 1.0f, 1.0f}); + + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = op.evaluate({&matrix, &begin, &end, &stride}, {}, {0,0,0,0,3}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + // z->printShapeInfo("z shape"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, Test_StridedSlice_1) { + auto x = NDArrayFactory::create('c', {1, 2}, {5.f, 2.f}); + auto a = NDArrayFactory::create('c', {1}, {0.f}); + auto b = NDArrayFactory::create('c', {1}, {1.f}); + auto c = NDArrayFactory::create('c', {1}, {1.f}); + auto exp = NDArrayFactory::create({5.0f, 2}); + + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(IndexingTests, Test_StridedSlice_2) { + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto a = NDArrayFactory::create('c', {2}, {1, 1}); + auto b = NDArrayFactory::create('c', {2}, {2, 2}); + auto c = NDArrayFactory::create('c', {2}, {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {5.0}); + + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + // z->printIndexedBuffer("Z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + 
+TEST_F(IndexingTests, Test_StridedSlice_3) { + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto a = NDArrayFactory::create('c', {2}, {1, 2}); + auto b = NDArrayFactory::create('c', {2}, {2, 3}); + auto c = NDArrayFactory::create('c', {2}, {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {6.0}); + + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(IndexingTests, Test_StridedSlice_4) { + auto x = NDArrayFactory::create('c', {1, 2}, {5, 2}); + auto a = NDArrayFactory::create('c', {1}, {0.}); + auto b = NDArrayFactory::create('c', {1}, {1}); + auto c = NDArrayFactory::create('c', {1}, {1}); + auto exp = NDArrayFactory::create({5.0f, 2}); + + sd::ops::strided_slice op; + auto result = op.evaluate({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1}); +// auto result = op.execute({&x, &a, &b, &c}, {}, {0, 0, 0, 0, 1, 0, 1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + //z->printIndexedBuffer("Z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(IndexingTests, Test_Subarray_Strided_1) { + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto exp = NDArrayFactory::create('c', {3, 2}, {1, 3, 4, 6, 7, 9}); + auto sub = x({0,0,0, 0,3,2}, true, true); + + ASSERT_TRUE(exp.isSameShape(sub)); + ASSERT_TRUE(exp.equalsTo(sub)); +} + + +/* +TEST_F(IndexingTests, MaskedSlice_5) { + + auto matrix('c', {3, 3, 3}, {1.f, 1.2f, 1.3f, 2.f, 2.2f, 2.3f, 3.f, 3.2f, 3.3f, 4.f, 4.2f, 4.3f, 5.f, 5.2f, 5.3f, 6.f, 6.2f, 6.3f, 7.f, 7.2f, 7.3f, 8.f, 8.2f, 8.3f, 9.f, 9.2f, 9.3f}); + auto exp('c', {2, 3}, { 4.f, 4.2f, 4.3f, 7.f, 7.2f, 7.3f}); + + // output = tf.strided_slice(a, [1, 0, 0], [3, 3, 3], shrink_axis_mask=5) + sd::ops::strided_slice op; + auto result = 
op.execute({&matrix}, {}, {0,0,0,0,2, 1, 0, 0, 3, 3, 3}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} +*/ \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/JavaInteropCudaTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/JavaInteropCudaTests.cu new file mode 100644 index 000000000..a8e3509bc --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/JavaInteropCudaTests.cu @@ -0,0 +1,89 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::ops; + +class JavaInteropCudaTests : public testing::Test { +public: + +}; + +TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_1) { + auto x = NDArrayFactory::create('c', {3, 5}); + auto y = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {3, 5}); + x.assign(1.f); + e.assign(2.f); + + sd::ops::add op; + Context context(1); + + context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer()); + context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); + + context.setOutputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + + PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_1"); + execCustomOp2(nullptr, op.getOpHash(), &context); + + pm.synchronize(); + + ASSERT_EQ(e, x); +} + +TEST_F(JavaInteropCudaTests, test_DeclarableOp_execution_2) { + NDArray x('c', {3, 1, 2}, sd::DataType::FLOAT32); + NDArray y('c', {2, 2}, sd::DataType::FLOAT32); + NDArray z('c', {3, 2, 2}, sd::DataType::BOOL); + NDArray e('c', {3, 2, 2}, sd::DataType::BOOL); + + x.assign(1.f); + y.assign(2.f); + e.assign(false); + + sd::ops::equals op; + Context context(1); + + context.setCudaContext(LaunchContext::defaultContext()->getCudaStream(), LaunchContext::defaultContext()->getReductionPointer(), LaunchContext::defaultContext()->getAllocationPointer()); + context.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), 
x.specialShapeInfo()); + context.setInputArray(1, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); + + context.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + PointersManager pm(LaunchContext::defaultContext(), "test_DeclarableOp_execution_2"); + execCustomOp2(nullptr, op.getOpHash(), &context); + + pm.synchronize(); + + ASSERT_EQ(e, z); +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/JavaInteropTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/JavaInteropTests.cpp new file mode 100644 index 000000000..4c2b5a971 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/JavaInteropTests.cpp @@ -0,0 +1,1500 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include +#include +#include +#include +#include "testlayers.h" +#include + +using namespace sd; +using namespace sd::ops; + +class JavaInteropTests : public testing::Test { +public: + +}; + + +TEST_F(JavaInteropTests, TestShapeExposure1) { + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); + + sd::ops::conv2d op; + + std::vector tArgs({}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + + Nd4jPointer ptrs[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weights.shapeInfo()}; + + auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); + + ASSERT_EQ(1, shapeList->size()); + + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); + ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); + ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); + ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); + + //int *ptr = (int *) shapeList[0]; + //delete[] ptr; + //delete shapeList; + + deleteShapeList((Nd4jPointer) shapeList); +} + + +TEST_F(JavaInteropTests, TestShapeExposure2) { + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1, 2, 5, 4}); + + sd::ops::shape_of op; + + std::vector tArgs({}); + std::vector iArgs({}); + + + Nd4jPointer ptrs[] = {(Nd4jPointer) input.shapeInfo()}; + + auto shapeList = calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 1, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); + + ASSERT_EQ(1, shapeList->size()); + + ASSERT_EQ(exp.rankOf(), 
shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); + + deleteShapeList((Nd4jPointer) shapeList); +} + +TEST_F(JavaInteropTests, TestShapeExposure3) { + auto x = NDArrayFactory::create('c', {5, 30}); + auto sizes = NDArrayFactory::create('c', {3}, {4, 15, 11}); + + std::vector list0 = {0,0, 0,4}; + std::vector list1 = {0,0, 4,19}; + std::vector list2 = {0,0, 19,30}; + + auto sub0 = x(list0, true); + auto sub1 = x(list1, true); + auto sub2 = x(list2, true); + + sub0.assign(0.0f); + sub1.assign(1.0f); + sub2.assign(2.0f); + + Nd4jPointer inputBuffers[] = {x.buffer(), sizes.buffer(), x.specialBuffer(), sizes.specialBuffer()}; + Nd4jPointer inputShapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)sizes.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)sizes.specialShapeInfo()}; + + sd::ops::split_v op; + + Nd4jLong iArgs[] = {1}; + auto hash = op.getOpHash(); + + auto shapeList = calculateOutputShapes2(nullptr, hash, inputBuffers, inputShapes, 2, nullptr, 0, iArgs, 1, nullptr, 0, nullptr, 0); + + ASSERT_EQ(3, shapeList->size()); + + ASSERT_TRUE(shape::equalsSoft(sub0.shapeInfo(), shapeList->at(0))); + ASSERT_TRUE(shape::equalsSoft(sub1.shapeInfo(), shapeList->at(1))); + ASSERT_TRUE(shape::equalsSoft(sub2.shapeInfo(), shapeList->at(2))); + + deleteShapeList((Nd4jPointer) shapeList); +} + +TEST_F(JavaInteropTests, Test_Squeeze_1) { + auto x = NDArrayFactory::create('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {6}); + auto e = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + + sd::ops::squeeze op; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), 
(Nd4jPointer)z.specialShapeInfo()}; + auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(JavaInteropTests, Test_RDiv_1) { + auto x = NDArrayFactory::create('c', {3}, {2, 2, 2}); + auto y = NDArrayFactory::create('c', {3}, {4, 6, 8}); + auto z = NDArrayFactory::create('c', {3}); + auto e = NDArrayFactory::create('c', {3}, {2, 3, 4}); + + NDArray::prepareSpecialUse({&z}, {&x, &y}); + + sd::ops::reversedivide op; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), (Nd4jPointer)z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&x, &y}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(JavaInteropTests, TestSconv2d_1) { + auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); + auto weightsP = NDArrayFactory::create('c', {2, 3, 1, 1}); + auto bias = NDArrayFactory::create('c', {2}); + auto output = NDArrayFactory::create('c', {3, 2, 8, 8}); + output.assign(0.0); + + input.linspace(1); + weightsD.linspace(1); + weightsP.linspace(1); + bias.linspace(1); + weightsD.permutei({2,3,1,0}); + weightsP.permutei({2,3,1,0}); + + auto expOutput = NDArrayFactory::create('c', {3, 2, 8, 8}); + + sd::ops::sconv2d op; + + NDArray::prepareSpecialUse({&output}, {&input, 
&weightsD, &weightsP, &bias}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), (Nd4jPointer) weightsD.buffer(), (Nd4jPointer) weightsP.buffer(), (Nd4jPointer) bias.buffer(), (Nd4jPointer) input.specialBuffer(), (Nd4jPointer) weightsD.specialBuffer(), (Nd4jPointer) weightsP.specialBuffer(), (Nd4jPointer) bias.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weightsD.shapeInfo(), (Nd4jPointer) weightsP.shapeInfo(), (Nd4jPointer) bias.shapeInfo(), (Nd4jPointer) input.specialShapeInfo(), (Nd4jPointer) weightsD.specialShapeInfo(), (Nd4jPointer) weightsP.specialShapeInfo(), (Nd4jPointer) bias.specialShapeInfo()}; + + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), (Nd4jPointer) output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer) output.specialShapeInfo()}; + + Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0, 0}; + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 4, ptrsOutBuffers, ptrsOutShapes, 1, + nullptr, 0, exp, 9, nullptr, 0, false); + + //output.printBuffer("output"); + NDArray::registerSpecialUse({&output}, {&input, &weightsD, &weightsP, &bias}); + + ASSERT_NEAR(1423, output.e(0), 1e-5); + //nd4j_printf("Iter %i passed...\n", e); +} + +TEST_F(JavaInteropTests, TestSconv2d_2) { + auto input = NDArrayFactory::create('c', {3, 3, 8, 8}); + auto weightsD = NDArrayFactory::create('c', {1, 3, 1, 1}); + auto output = NDArrayFactory::create('c', {3, 3, 8, 8}); + output.assign(0.0); + + input.linspace(1); + weightsD.linspace(1); + weightsD.permutei({2,3,1,0}); + + auto expOutput = NDArrayFactory::create('c', {3, 3, 8, 8}); + + sd::ops::sconv2d op; + + NDArray::prepareSpecialUse({&output}, {&input, &weightsD}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), (Nd4jPointer) weightsD.buffer(), input.specialBuffer(), weightsD.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), 
(Nd4jPointer) weightsD.shapeInfo(), (Nd4jPointer)input.specialShapeInfo(), (Nd4jPointer)weightsD.specialShapeInfo()}; + + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + + Nd4jLong exp[] = {1, 1, 1, 1, 0, 0, 1, 1, 0}; + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); + + NDArray::registerSpecialUse({&output}, {&input, &weightsD}); + + ASSERT_NEAR(1, output.e(0), 1e-5); +} + + +TEST_F(JavaInteropTests, TestMaxPooling2d_1) { + auto input = NDArrayFactory::create('c', {1, 2, 4, 5}); + auto output = NDArrayFactory::create('c', {1, 2, 4, 5}); + input.linspace(1); + + NDArray::prepareSpecialUse({&output}, {&input}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + + sd::ops::maxpool2d op; + + Nd4jStatus status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs.data(), 9, nullptr, 0, false); + + NDArray::registerSpecialUse({&output}, {&input}); + ASSERT_EQ(ND4J_STATUS_OK, status); + +} +TEST_F(JavaInteropTests, TestCol2Im_1) { + /* + o.d.n.l.c.ConvolutionLayer - eps shape: [6, 1, 2, 2, 2, 4, 5, 160, 4, 2, 1, 40, 8, 0, -1, 99] + o.d.n.l.c.ConvolutionLayer - epsNext shape: [4, 1, 2, 4, 5, 20, 20, 5, 1, 0, 1, 99] + o.d.n.l.c.ConvolutionLayer - Strides: [1, 1] + o.d.n.l.c.ConvolutionLayer - Padding: [0, 0] + o.d.n.l.c.ConvolutionLayer - Input: [4,5] + 
o.d.n.l.c.ConvolutionLayer - Dilation: [1, 1] + */ + auto input = NDArrayFactory::create('c', {1, 2, 2, 2, 4, 5}); + auto output = NDArrayFactory::create('c', {1, 2, 4, 5}); + input.linspace(1); + + NDArray::prepareSpecialUse({&output}, {&input}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + + sd::ops::col2im op; + + Nd4jLong exp[] = {1, 1, 1, 1, 4, 5, 1, 1, 1}; + + auto hash = op.getOpHash(); + + execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 9, nullptr, 0, false); + + NDArray::registerSpecialUse({&output}, {&input}); + + ASSERT_TRUE(output.meanNumber().e(0) > 0.0f); +} + +TEST_F(JavaInteropTests, TestPNorm_1) { + /* + o.d.n.l.c.s.SubsamplingLayer - input: [4, 1, 3, 4, 4, 16, 16, 4, 1, 0, 1, 99] + o.d.n.l.c.s.SubsamplingLayer - output: [4, 1, 3, 3, 3, 27, 9, 3, 1, 0, 1, 99] + o.d.n.l.c.s.SubsamplingLayer - Kernel: [2, 2] + o.d.n.l.c.s.SubsamplingLayer - Strides: [1, 1] + o.d.n.l.c.s.SubsamplingLayer - Pad: [0, 0] + o.d.n.l.c.s.SubsamplingLayer - Dilation: [1, 1] + o.d.n.l.c.s.SubsamplingLayer - Same: false + o.d.n.l.c.s.SubsamplingLayer - pnorm: 2 + */ + auto input = NDArrayFactory::create('c', {1, 3, 4, 4}); + auto output = NDArrayFactory::create('c', {1, 3, 3, 3}); + input.linspace(1); + + NDArray::prepareSpecialUse({&output}, {&input}); + + sd::ops::pnormpool2d op; + + Nd4jLong exp[] = {2, 2, 1, 1, 0, 0, 1, 1, 0, 2, 0, 0}; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = 
{(Nd4jPointer) output.buffer(), output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, exp, 11, nullptr, 0, false); + + NDArray::registerSpecialUse({&output}, {&input}); + + ASSERT_TRUE(output.meanNumber().e(0) > 0.0); +} + + +TEST_F(JavaInteropTests, TestInplace_1) { + auto input = NDArrayFactory::create('c', {10, 10}); + //auto exp('c', {10, 10}); + input.linspace(1); + + NDArray::prepareSpecialUse({}, {&input}); + + sd::ops::clipbyvalue op; + + double extras[] = {-1.0f, 1.0f}; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + + + Nd4jStatus result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, nullptr, nullptr, 0, extras, 2, nullptr, 0, nullptr, 0, true); + + NDArray::registerSpecialUse({}, {&input}); + + ASSERT_EQ(ND4J_STATUS_OK, result); + + ASSERT_NEAR(1.0, input.meanNumber().e(0), 1e-5); +} + +TEST_F(JavaInteropTests, Test_Synonyms_1) { + auto op = OpRegistrator::getInstance().getOperation("RDiv"); + auto opRef = OpRegistrator::getInstance().getOperation("reversedivide"); + std::string nameExp("reversedivide"); + + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(opRef != nullptr); + + std::string name = *(op->getOpName()); + std::string nameRef = *(opRef->getOpName()); + + ASSERT_EQ(nameExp, nameRef); + ASSERT_EQ(nameRef, name); +} + +TEST_F(JavaInteropTests, Test_Synonyms_2) { + auto op = OpRegistrator::getInstance().getOperation("RDiv"); + auto opRef = OpRegistrator::getInstance().getOperation("reversedivide"); + std::string nameExp("reversedivide"); + + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(opRef != nullptr); + + std::string name = *(op->getOpName()); + std::string nameRef = 
*(opRef->getOpName()); + + ASSERT_EQ(nameExp, nameRef); + ASSERT_EQ(nameRef, name); +} + +TEST_F(JavaInteropTests, Test_Synonyms_3) { + auto op = OpRegistrator::getInstance().getOperation("RDiv"); + auto opRef = OpRegistrator::getInstance().getOperation("reversedivide"); + std::string nameExp("reversedivide"); + + ASSERT_TRUE(op != nullptr); + ASSERT_TRUE(opRef != nullptr); + + std::string name = *(op->getOpName()); + std::string nameRef = *(opRef->getOpName()); + + ASSERT_EQ(nameExp, nameRef); + ASSERT_EQ(nameRef, name); +} + +TEST_F(JavaInteropTests, Test_FastPath_Validation_1) { + auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + sd::ops::softmax op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); +} + +TEST_F(JavaInteropTests, Test_FastPath_Validation_2) { + auto x = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto z = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + sd::ops::softmax op; + auto status = op.execute(&ctx); + ASSERT_NE(Status::OK(), status); +} + +TEST_F(JavaInteropTests, Test_FastPath_Validation_3) { + auto x = NDArrayFactory::create('c', {3, 5}, { 0.7788f, 0.8012f, 0.7244f, 0.2309f, 0.7271f, + 0.1804f, 0.5056f, 0.8925f, 0.5461f, 0.9234f, + 0.0856f, 0.7938f, 0.6591f, 0.5555f, 0.1596f}); + + auto min = NDArrayFactory::create({ -0.2283f, -0.0719f, -0.0154f, -0.5162f, -0.3567f}); + auto max = NDArrayFactory::create({ 0.9441f, 0.5957f, 0.8669f, 0.3502f, 0.5100f}); + + auto z = NDArrayFactory::create('c', {3, 5}); + + Context ctx(1); + 
ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setInputArray(1, min.buffer(), min.shapeInfo(), min.specialBuffer(), min.specialShapeInfo()); + ctx.setInputArray(2, max.buffer(), max.shapeInfo(), max.specialBuffer(), max.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + sd::ops::fake_quant_with_min_max_vars_per_channel op; + ASSERT_ANY_THROW(op.execute(&ctx)); +} + +TEST_F(JavaInteropTests, Test_empty_cast_1) { + auto x = NDArrayFactory::create('c', {1, 0, 2}); + auto z = NDArrayFactory::create('c', {1, 0, 2}); + auto e = NDArrayFactory::create('c', {1, 0, 2}); + + Nd4jLong iArgs[] = {10}; + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setIArguments(iArgs, 1); + + sd::ops::cast op; + auto result = op.execute(&ctx); + ASSERT_EQ(Status::OK(), result); + ASSERT_EQ(e, z); +} + +/* +TEST_F(JavaInteropTests, test_avgpooling_edge_1) { + int inOutH = 35; + int inOutW = 35; + int inOutC = 192; + + auto x = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + auto z = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + x.linspace(1.0); + z.linspace(1.0); + + NDArray::prepareSpecialUse({&z}, {&x}); + + sd::ops::avgpool2d op; + //auto result = op.execute({&x}, {}, {3,3, 1,1, 0,0, 1,1, 1, 0, 1}); + + Nd4jLong exp[] = {3,3, 1,1, 0,0, 1,1, 1, 0, 1}; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), z.special()}; + + auto result = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, 
nullptr, 0, exp, 11, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&x}); + + ASSERT_EQ(Status::OK(), result); + + int totalPadHeight = (inOutH - 1) * 1 + 3 - inOutH; + int padTop = totalPadHeight / 2; + int padBottom = totalPadHeight - totalPadHeight / 2; + + int k = 3; + + auto m = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + auto c = NDArrayFactory::create('c', {1, inOutH, inOutW, inOutC}); + + for (int h = 0; h < inOutH; h++) { + for (int w = 0; w < inOutW; w++) { + int hFrom = h - padTop; + int wFrom = w - padBottom; + + int hTo = hFrom + k; + int wTo = wFrom + k; + + hFrom = sd::math::nd4j_max(0, hFrom); + wFrom = sd::math::nd4j_max(0, wFrom); + + hTo = sd::math::nd4j_min(inOutH, hTo); + wTo = sd::math::nd4j_min(inOutW, wTo); + + int idxOut[4]; + int idxIn[4]; + for (int ch = 0; ch < inOutC; ch++) { + idxOut[1] = h; + idxOut[2] = w; + idxOut[3] = ch; + idxIn[3] = ch; + + for (int kh = hFrom; kh < hTo; kh++) { + for (int kw = wFrom; kw < wTo; kw++) { + idxIn[1] = kh; + idxIn[2] = kw; + + auto inVal = x.e(0, kh, kw, ch); + m.p(0, h, w, ch, inVal + m.e(0, h, w, ch)); + c.p(0, h, w, ch, 1 + c.e(0, h, w, ch)); + } + } + } + } + } + m /= c; + + //z.printIndexedBuffer("z buffer", 100); + //m.printIndexedBuffer("m buffer", 100); + int cnt = 0; + int lim = 10; + for (int e = 0; e < z.lengthOf() && cnt < lim; e++) { + auto _m = m.e(e); + auto _z = z.e(e); + auto eq = sd::math::nd4j_eq(_m, _z, 1e-5); + if (!eq) { + nd4j_printf("Difference at element e [%i]: <%f> vs <%f>\n", e, _m, _z); + cnt++; + } + } + + ASSERT_EQ(m, z); +} + + +TEST_F(JavaInteropTests, Test_GraphReuse_1) { + uint8_t* data = sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); + + registerGraph(nullptr, 119, (Nd4jPointer) data); + + ASSERT_TRUE(GraphHolder::getInstance().hasGraph(119)); + + unregisterGraph(nullptr, 119); + + ASSERT_FALSE(GraphHolder::getInstance().hasGraph(119)); + + + delete[] data; +} + +TEST_F(JavaInteropTests, Test_GraphReuse_2) { + 
//Environment::getInstance().setDebug(true); + //Environment::getInstance().setVerbose(true); + + auto exp0 = NDArrayFactory::create('c', {3}, {3, 3, 3}); + auto exp1 = NDArrayFactory::create('c', {3}, {6, 6, 6}); + auto exp2 = NDArrayFactory::create('c', {3}, {9, 9, 9}); + + // we load graph from file, because we're not in java here, and dont have buffer ready + uint8_t* data = sd::graph::readFlatBuffers("./resources/reduce_dim_false.fb"); + + // we ensure that there's no such a graph stored earlier + ASSERT_FALSE(GraphHolder::getInstance().hasGraph(119)); + + // register the graph, to call for it later + registerGraph(nullptr, 119, (Nd4jPointer) data); + + // and ensure we're ok + ASSERT_TRUE(GraphHolder::getInstance().hasGraph(119)); + + + + // run stuff + + auto input_0 = NDArrayFactory::create('c', {3, 3}); + input_0.assign(1.0f); + + int idx[] = {1}; + + Nd4jPointer inputs_0[] = {(Nd4jPointer) input_0.buffer()}; + Nd4jPointer shapes_0[] = {(Nd4jPointer) input_0.shapeInfo()}; + + // now we're executing stored graph and providing replacement for input variable + auto res_0 = executeStoredGraph(nullptr, 119, inputs_0, shapes_0, idx, 1); + ASSERT_EQ(ND4J_STATUS_OK, res_0->status()); + ASSERT_EQ(1, res_0->size()); + + auto z0 = res_0->at(0)->getNDArray(); + ASSERT_TRUE(exp0.isSameShape(z0)); + + + auto input_1 = NDArrayFactory::create('c', {3, 3}); + input_1.assign(2.0f); + + Nd4jPointer inputs_1[] = {(Nd4jPointer) input_1.buffer()}; + Nd4jPointer shapes_1[] = {(Nd4jPointer) input_1.shapeInfo()}; + + // doing it again + auto res_1 = executeStoredGraph(nullptr, 119, inputs_1, shapes_1, idx, 1); + ASSERT_EQ(ND4J_STATUS_OK, res_1->status()); + ASSERT_EQ(1, res_1->size()); + + auto z1 = res_1->at(0)->getNDArray(); + ASSERT_TRUE(exp1.isSameShape(z1)); + + + auto input_2 = NDArrayFactory::create('c', {3, 3}); + input_2.assign(3.0f); + + Nd4jPointer inputs_2[] = {(Nd4jPointer) input_2.buffer()}; + Nd4jPointer shapes_2[] = {(Nd4jPointer) input_2.shapeInfo()}; + + // and 
again + auto res_2 = executeStoredGraph(nullptr, 119, inputs_2, shapes_2, idx, 1); + ASSERT_EQ(ND4J_STATUS_OK, res_1->status()); + ASSERT_EQ(1, res_2->size()); + + auto z2 = res_2->at(0)->getNDArray(); + ASSERT_TRUE(exp2.isSameShape(z2)); + + + //////// clean out + unregisterGraph(nullptr, 119); + + ASSERT_FALSE(GraphHolder::getInstance().hasGraph(119)); + + + delete[] data; + delete res_0; + delete res_1; + delete res_2; +} +*/ + +TEST_F(JavaInteropTests, Test_Greater_1) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 2}, {1, 2, 0, 0}); +// auto o = NDArrayFactory::create('c', {2, 2}, {3, 3, 3, 3}); + auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); + + auto exp = NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); + + NDArray::prepareSpecialUse({&o}, {&x, &y}); + + sd::ops::greater op; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.buffer(), o.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.shapeInfo(), (Nd4jPointer)o.specialShapeInfo()}; + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::registerSpecialUse({&o}, {&x, &y}); + ASSERT_TRUE(exp.equalsTo(&o)); +} + + +TEST_F(JavaInteropTests, Test_Greater_2) { + auto x = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 1.f, 2.f}); + auto y = NDArrayFactory::create('c', {2, 2}, {1.f, 2.f, 0.f, 0.f}); + auto o = NDArrayFactory::create('c', {2, 2}, {true, true, true, true}); + + auto exp = NDArrayFactory::create('c', {2, 2}, {false, false, true, true}); + + sd::ops::greater op; + + NDArray::prepareSpecialUse({&o}, 
{&x, &y}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.buffer(), o.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.shapeInfo(), (Nd4jPointer)o.specialShapeInfo()}; + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::registerSpecialUse({&o}, {&x, &y}); + + ASSERT_TRUE(exp.equalsTo(&o)); +} + +TEST_F(JavaInteropTests, Test_Boolean_Op_1) { + + sd::ops::is_non_decreasing op; + + auto x = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + auto o = NDArrayFactory::create(false); + auto exp = NDArrayFactory::create(1); + + NDArray::prepareSpecialUse({&o}, {&x}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) o.buffer(), o.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) o.shapeInfo(), (Nd4jPointer)o.specialShapeInfo()}; + + auto hash = op.getOpHash(); + auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::registerSpecialUse({&o}, {&x}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(exp.equalsTo(&o)); +} + + +TEST_F(JavaInteropTests, Test_Inplace_Outputs_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {2, 3}); + + sd::ops::test_output_reshape op; + + NDArray::prepareSpecialUse({&z}, 
{&x}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + + auto hash = op.getOpHash(); + auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&x}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + + +TEST_F(JavaInteropTests, Test_Inplace_Outputs_2) { + auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto y = NDArrayFactory::create(2.0f); + auto z = NDArrayFactory::create('f', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + + + sd::ops::add op; + + NDArray::prepareSpecialUse({&z}, {&x, &y}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), (Nd4jPointer) y.buffer(), x.specialBuffer(), y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer) y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + + auto hash = op.getOpHash(); + auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::prepareSpecialUse({&z}, {&x, &y}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); + ASSERT_FALSE(e.ordering() == z.ordering()); +} + +TEST_F(JavaInteropTests, Test_Inplace_Outputs_3) { + auto input = 
NDArrayFactory::create('c', {2, 3, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}); + auto indices = NDArrayFactory::create('c', {1, 6}, {0,1, 2,2, 1,2}); + auto output = NDArrayFactory::create('f', {2, 1, 6, 4}); + auto e = NDArrayFactory::create('c', {2, 1, 6, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 9,10,11,12, 9,10,11,12, 5, 6, 7, 8, 9,10,11,12, 13,14,15,16, 17,18,19,20, 21,22,23,24, 21,22,23,24, 17,18,19,20, 21,22,23,24}); + + sd::ops::gather op; + + NDArray::prepareSpecialUse({&output}, {&input, &indices}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) input.buffer(), (Nd4jPointer) indices.buffer(), input.specialBuffer(), indices.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) indices.shapeInfo(), (Nd4jPointer)input.specialShapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) output.buffer(), output.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) output.shapeInfo(), (Nd4jPointer)output.specialShapeInfo()}; + + Nd4jLong iArgs[] = {1}; + + auto hash = op.getOpHash(); + auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 1, nullptr, 0, false); + + NDArray::registerSpecialUse({&output}, {&input, &indices}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(e.isSameShape(output)); + ASSERT_TRUE(e.equalsTo(output)); + ASSERT_FALSE(e.ordering() == output.ordering()); +} + +TEST_F(JavaInteropTests, Test_Reduce3_EdgeCase) { + auto x = NDArrayFactory::create('c', {3, 4, 5}); + auto y = NDArrayFactory::create('c', {3, 4, 5}); + auto z = NDArrayFactory::create('c', {5}); + + auto dims = NDArrayFactory::create('c', {2}, {0, 1}); + dims.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; + #ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[6] {nullptr, context->getCudaStream(), 
context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer()}; + #endif + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {0,1}); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {0,1}); + + NDArray::prepareSpecialUse({&z}, {&x, &y, &dims}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dims.dataBuffer()); + + execReduce3Tad(extraPointers, 2, &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dims.shapeInfo(), dims.specialShapeInfo(), packX.platformShapeInfo(), + packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + + NDArray::registerSpecialUse({&z}, {&x, &y, &dims}); + + delete []extraPointers; +} + +/* +TEST_F(JavaInteropTests, Test_SimpleIf_Output) { + Environment::getInstance().setDebug(true); + Environment::getInstance().setVerbose(false); + + auto pl = sd::graph::readFlatBuffers("./resources/simpleif_0_1.fb"); + auto ptr = executeFlatGraph(nullptr, pl); + + Environment::getInstance().setDebug(false); + Environment::getInstance().setVerbose(false); + + delete[] pl; + delete ptr; +} +*/ + +TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_double) { + + auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111, 2.20166993, 2.91434479, 5.43639755, -2.10573769, 4.08528662, 5.86908436, -4.46203756, 2.21057916, 5.35849190, 0.01394637, 4.40566349, 7.07982206, -0.09633455, 2.42429352, 3.97301817, -1.89553940, 1.99690318, 6.33141708, 0.55401880, 1.70707977, 5.55204201, -0.03513752, 1.60011971, 2.62700319, -2.74582434, 3.06697464, 1.06277943, -1.16075921, -0.78095782, 9.72352791, -1.22686064, 1.99644792, 7.35571337, 1.40607321, 0.11390255, 9.53334427, 2.28303599, -1.66728830, 6.16678810, -0.04532295, -1.97708666, 
9.74906158, 1.46223176, -1.46734393, 4.30761862, -1.23790228, 1.24823606, 6.13938427, -3.83689475, -1.19625473, 7.91535568, 6.05868721, -3.22946382, 8.81633949, -0.19967777, 0.66053957, 2.30919123, 0.74543846, -0.39347672, 11.11058044, 0.53720862, 1.52645731, 5.70012379, -1.15213466, 1.16451406, 7.00526333, 1.57362783, -2.44384766, 5.54213285, -1.98828590, -0.70483637, 7.88281822, -3.59875536, 0.80745387, 13.41578484, -1.55507684, -0.65855008, 9.32583523, -0.14544789, 0.73436141, 3.61176538, -1.71268058, -2.58490300, 9.09280205, -3.27405524, -2.04569697, 4.44761324, -0.62955856, -2.61917663, 8.04890442, 0.54579324, 0.85929775, 9.82259560, -1.93825579, 0.77703512, 4.67090321, -4.79267597, -2.38906908, 9.31265545, 0.96026313, -1.14109385, 11.54231834, -0.01417295, -0.39500344, 8.49191666, 0.55300158, 2.79490185, 6.92466164, 1.72254205, 2.82222271, 8.83112717, 2.95033407, 2.18054962, 6.73509789, -2.22272944, 0.51127720, -1.04563558, 2.15747333, -2.30959272, 9.55441570, 1.50396204, 1.77370787, 7.38146257, -1.79076433, 3.20961165, 7.18864202, 2.91217351, 0.43018937, 7.11078024, -1.17386127, -0.16817921, 6.12327290, -2.82205725, 3.30696845, 13.51291752, -1.30856836, -2.38332748, 11.09487438, -1.47190213, -0.53050828, 4.38285351, -5.07309771, 1.50714362, 5.72274446, -2.85825086, -0.89673209, 3.73791552, -0.67708802, -4.13149452, -0.00671843, -0.26566532, 0.32961160, 7.14501762, -1.41608179, -4.96590328, 12.26205540, -0.65158135, -0.88641000, 6.95777559, -0.79058206, -0.10260171, 7.87169170, 1.35921454, 1.11759663, 5.46187401, -2.57214499, 2.48484039, 4.04043484, -2.07137156, -1.42709637, 9.25487137, -0.12605135, -2.66949964, 2.89412403, 0.74451172, -2.96250391, 3.99258423, 0.27084303, 0.32213116, 5.42332172, -0.44414216, 1.70881832, 6.69346905, 0.53058422, -4.73146200, 4.22051668, 2.24834967, 0.66996074, 4.30173683, 0.11849818, -4.07520294, 8.27318478, -2.54398274, -2.86705542, 10.11775303, -0.99382895, 0.65881538, 7.93556786, -1.27934420, -1.69343162, 9.68042564, 
-1.02609646, -1.18189347, 5.75370646, -1.67888868, -4.48871994, 4.79537392, -0.79212248, -0.19855022, 6.15060997, -0.01081491, 3.64454579, 10.82562447, 1.58859253, -2.65847278, 8.60093212, -1.59196103, 0.07635692, 11.76175690, -1.17453325, 0.10122013, 6.86458445, -2.18891335, -2.74004745, 8.07066154, 0.71818852, -2.03035975, 6.31053686, 0.51509416, 1.39789927, 9.43515587, 2.04256630, 0.13985133, 4.65010691, 2.40911126, -0.36255789, -3.06867862, -0.45225358, -1.56778407, 6.05917358, -1.09891272, 1.77184200, 6.46248102, 0.96042323, -0.24346280, 4.63436460, -4.69907761, 1.25187206, 11.46173859, -2.21917558, 1.28007793, 6.92173195, 2.11268163, -3.47389889, 5.08722782, -3.03950930, -4.17154264, 11.30568314, 0.80361372, 2.53214502, 7.18707085, -4.49114513, 2.85449266, 10.14906883, -0.31974933, -0.84472644, -0.52459574, 0.12921631, -1.81390119, 2.76170087, 1.03982210, 2.91744232, -0.29048753, 5.87453508, -1.53684759, 1.85800636, -0.91404629, 1.28954852, 5.11354685, -2.47475505, -1.33179152, 2.58552408, 1.37316465, -3.32339454, 1.54122913, 3.24953628, -0.29758382, 2.82391763, -1.51142192, -1.22699404, 6.75745535, 0.65452754, -3.29385471, 2.06008053, 2.53172946, -4.23532820, -1.53909743, -0.07010663, -1.42173731, 7.29031610, -0.18448229, 4.59496164, 6.73027277, 0.73441899, 0.14426160, 4.14915276, -2.97010231, 6.05851364, 4.95218086, -2.39145470, 2.40494704, 2.10288811, 0.53503096, 1.44511235, 6.66344261, -3.05803776, 7.21418667, 3.30303526, -0.24163735, 3.47409391, 3.64520788, 2.15189481, -3.11243272, 3.62310791, 0.37379482, 0.40865007, -0.83132005, -4.78246069, 2.07030797, 6.51765442, 3.16178989, 5.06180477, 3.78434467, -0.96689719, 0.35965276, 5.89967585, 1.40294051, 1.11952639, 10.59778214, 0.26739889, -1.61297631, 6.24801159, -0.93914318, -0.57812452, 9.92604542, -0.73025000, -3.38530874, 2.45646000, -2.47949195, 0.51638460, 10.65636063, 1.97816694, -3.00407791, 2.66914415, -0.81951088, -0.23316640, 2.40737987, -2.70007610, 1.51531935, 4.08860207, -0.27552786, 
-1.31721711, 7.11568260, -3.33498216, -4.02545023, 7.22675610, -0.81690705, -2.52689576, 1.04016697, -0.79291463, -0.34875512, 10.00498390, -4.24167728, 1.46162593, 11.82569408, -1.70359993, -0.30161047, 16.44085884, -0.82253462, -0.09435523, 6.13080597, -0.20259480, 0.68308711, 6.15663004, -6.61776876, 0.33295766, 2.55449438, -0.17819691, -1.14892209, 5.56776142, 1.99279118, 1.33035934, 4.45823956, 3.34916544, -2.59905386, 6.16164446, -2.03881931, -2.45273542, 12.46793365, -2.22743297, 2.83738565, 8.48628139, -1.39347959, -1.30867767, 11.08041477, -4.00363779, 2.09183025, 11.30395889, -2.20504737, 1.37426853, 8.98735619, 1.04676604, -0.72757077, 8.28050232, -6.70741081, -0.65798020, 5.68592072, -0.60760021, 0.35854483, 6.26852131, 1.94100165, 1.32112014, 0.80987954, -1.74617672, -0.25434083, 7.16045523, 1.58884013, -2.64847064, 13.14820385, 1.21393633, -2.47258949, 9.41650105, -0.79384226, 2.48954105, 10.95629311, 0.47723705, 4.02126694, 8.02593136, -2.20726371, -1.18794477, 1.50836647, 0.93118095, -1.73513174, 8.85493565, -2.99670315, -0.79055870, 2.39473820, 2.05046916, -2.38055134, 11.82299423, 0.15609655, 0.68744308, 5.66401434, -0.69281673, 2.09855556, 7.74626589, -0.34283102, 1.00542057, 9.95838642, 0.80161905, 2.33455157, 9.80057335, -0.93561798, 2.56991577, 8.29711342, 0.94213426, 0.44209945, 11.70259857, 0.92710167, 2.60957146, 0.24971688, -0.86529571, 3.78628922, 6.80884457, -0.68178189, 2.21103406, 3.18895817, 0.60283208, -2.92716241, 6.72060776, -1.06625068, 2.56543374, 9.97404480, 3.58080721, -0.94936347, 10.16736984, -1.38464379, 1.18191063, 6.66179037, -3.56115270, 0.32329530, 10.90870762, 2.20638227, 0.19653285, 7.34650040, -3.63859272, -1.03027737, 5.98829985, -3.66606474, -3.89746714, 8.63469028, 1.22569811, 1.63240814, 3.74385309, 0.58243257, -0.56981975, 3.69260955, 1.00979900, -1.44030499, 8.57058144, -1.10648811, 1.20474911, 5.43133020, -2.14822555, -0.07928789, 11.25825310, 0.19645604, -5.49546146, 10.41917038, -0.68178523, -2.99639869, 
6.50054455, 0.46488351, -5.42328453, 9.09500027, -2.82107449, 0.05601966, 15.34610748, -0.06820253, 3.86699796, 10.73316956, -3.04795432, -0.14702171, 5.64813185, 1.44028485, -2.47596145, 0.07280898, -3.03187990, -1.35183525, 9.35835648, 2.72966957, 1.88199532, 10.36187744, -0.22834805, -3.26738238, 6.92025137, -2.34061313, 4.77379704, 5.28559113, -2.96323752, -1.76186585, 5.94436455, 0.38647744, -5.73869514, 6.76849556, 1.40892124, -1.19068217, 5.37919092, -6.65328646, 3.62782669, 12.34744644, 2.44762444, -4.19242620, 6.14906216, 0.08121119, 0.61355996, 2.69666457, -1.88962626, -0.55314136, 1.84937525, 1.56048691, 1.17460012, 3.75674725, 1.06198275, -5.74625874, 5.41645575, -1.28946674, -1.51689398, 4.32400894, -0.05222082, -4.83948946, 1.80747867, 1.63144708, -2.73887825, 1.63975775, -2.02163982, -0.16210437, 2.93518686, 1.14427686, -2.83246303, 4.79283667, 2.69697428, -3.12678456, -1.19225168, -2.37022972, -3.09429741, 1.94225383, -1.13747168, -2.55048585, 5.40242243, 1.12777328, 3.43713188, 3.62658787, -2.16878843, 0.30164462, 2.97407579, -0.07275413, -1.31149673, 4.70066261, -2.01323795, 4.85255766, 4.59128904, 1.68084168, 1.60336494, 6.58138466, -1.04759812, 2.69906545, 3.55769277, -0.74327278, 2.65819693, 5.39528131, 2.11248922, -1.06446671, 5.24546766, -2.43146014, 4.58907509, 0.06521678, -2.24503994, 2.45722699, 6.94863081, 0.35258654, 2.83396196, 9.92525196, -1.12225175, -0.34365177, 7.19116688, -4.39813757, 0.46517885, 13.22028065, -2.57483673, -6.37226963, 7.58046293, -2.74600363, 0.42231262, 8.04881668, 0.17289802, -0.53447008, 16.55157471, -5.63614368, 0.39288223, 3.37079263, 1.26484549, -0.12820500, 8.46440125, -4.39304399, 2.97676420, 0.65650189, 0.83158541, -1.11556435, 6.32885838, -0.36087769, 2.80724382, 9.90292645, 1.15936041, 0.20947981, 6.91249275, -2.67404819, 2.93782163, 6.65656614, -2.30828357, 2.98214006, 6.80611229, -4.93821478, -7.66555262, 7.59763002, -0.54159302, 3.87403512, 12.42607784, 2.59284401, -0.23375344, 8.95293331, 
-0.71807784, 0.61873478, 8.66713524, 1.24289191, -2.37835455, 2.08071637, -0.88315344, -3.41891551, 6.85245323, 1.73007369, 1.02169311, 7.69170332, -2.85411978, 2.69790673, 8.12906551, -1.19351399, -2.26442742, 12.26104450, -0.75579089, -1.73274946, 10.68729019, 2.20655656, -0.90522075, 12.42165184, -1.67929137, 2.44851565, 9.31565762, -0.06645700, 1.52762020, 6.18427515, -1.68882596, 3.70261097, 3.02252960, -3.44125366, -1.31575799, 2.84617424, -0.96849400, -4.52356243, 9.95027161, 0.19966406, -0.78874779, 8.18595028, -4.08300209, 1.75126517, 0.96418417, -4.04913044, -0.95200396, 12.03637886, -0.03041124, 0.41642749, 8.88267422, -3.24985337, -2.24919462, 7.32566118, 0.16964148, -2.74123430, 7.05264473, -3.30191112, 0.17163286, 4.81851053, -1.64463484, -0.85933101, 7.29276276, 2.34066939, -2.14860010, 3.46148157, -0.01782012, 1.51504040, 4.79304934, 1.85281146, -1.70663762, 6.93470192, -4.15440845, -1.25983095, 10.52491760, 0.42930329, -1.85146868, 11.70042324, -0.41704914, 3.83796859, 9.21148491, -2.79719448, 0.79470479, 6.26926661, -5.85230207, 3.95105338, 7.84790897, -1.38680744, -1.78099084, 11.95235348, -2.99841452, -1.34507811, 6.15714645, -1.07552516, -2.81228638, 1.66234732, -4.55166149, -1.92601109, 8.64634514, -0.48158705, 3.31595659, 7.67371941, 2.56964207, 0.12107098, 4.56467867, -0.93541539, 1.39432955, 11.99714088, 1.05353570, -2.13099813, 3.67617917, 3.45895386, 1.37365830, 8.74344158, -4.17585802, 1.43908918, 6.28764772, 3.97346330, -0.69144285, 9.07983303, -0.41635889, -0.14965028, 8.85469818, 1.11306190, 2.59440994, 5.38982344, -1.07948279, 1.37252975, 10.26984596, -0.09318046, 2.73104119, 12.45902252, -1.55446684, -2.76124811, 12.19395065, -0.51846564, 1.02764034, 11.42673588, -0.95940983, -0.04781032, 8.78379822, -4.88957930, 0.32534006, 11.97696400, -3.35108662, 1.95104563, 4.46915388, -2.32061648, 3.45230985, 8.29983711, 2.81034684, -2.35529327, 6.07801294, -0.98105043, -0.05359888, 2.52291036, -0.01986909, -2.35321999, 10.51954269, 
2.11145401, 3.53506470, 7.29093266, 0.03721160, -1.13496494, 7.43886709, -5.84201956, 2.50796294, 12.14647675, 2.77490377, -2.18896222, 6.05641937, 5.32617044, 1.04221284, 10.79106712, -2.95749092, -2.75414610, 11.30037117, -3.40654182, -2.24673963, 7.49126101, 0.70811015, -6.18003702, 13.83951187, -1.01204085, 1.36298490, -1.04451632, 2.42435336, -0.02346706, -0.85528886, 1.04731262, 0.22192979, 4.15708160, 0.34933877, 0.04814529, 2.24107265, 0.49676740, -1.47752666, 0.45040059, -0.70471478, -1.19759345, 0.21711677, 0.88461423, -2.76830935, 5.52066898, 1.97664857, -1.75381601, 3.45877838, 1.52617192, -1.61350942, 0.85337949, 1.97610760, -3.40310287, 3.40319014, -3.38691044, -0.71319139, 1.65463758, -0.60680127, -1.80700517, 8.02592373, 2.59627104, 2.65895891, 5.93043184, -4.48425817, 3.92670918, 4.19496679, -2.28286791, 6.41634607, 5.72330523, 1.16269672, -0.28753027, 2.46342492, 0.36693189, 0.26712441, 6.37652683, -2.50139046, 2.43923736, 5.56310415, 0.98065847, 1.04267502, 4.16403675, -0.04966142, 4.40897894, 3.72905660, -3.46129870, 3.59962773, 1.34830284, -1.76661730, 0.47943926, 5.29946661, -1.12711561, 1.26970029, 15.17655945, -1.50971997, 5.81345224, 8.48562050, -4.36049604, 2.48144460, 8.23780441, -3.46030426, -0.84656560, 5.94946814, 1.12747943, -2.65683913, 8.69085693, 1.31309867, -2.79958344, 8.76840591, -1.56444156, 1.62710834, 2.41177034, -0.72804940, 5.70619011, 4.67169666, -0.86167198, -1.83803177, 2.96346045, 2.82692933, -2.81557131, 7.11113358, -1.90071094, 2.54244423, 11.19284058, -0.06298946, -1.71517313, 12.98388577, 0.84510714, 3.00816894, 2.57200313, 0.03899818, -1.49330592, 9.60099125, -3.59513044, -1.30045319, 7.09241819, -0.65233821, -2.33627677, 8.81366920, 0.84154201, 1.03312039, 9.85289097, 0.19351870, 1.78496623, 7.34631205, -2.16530800, -0.65016162, 2.46842360, 0.24016285, -1.24308395, 4.78175163, -0.97682536, 2.20942235, 6.68382788, 3.76786447, -1.44454038, 6.26453733, -3.23575711, -2.30137897, 9.53092670, -5.55222607, 3.25999236, 
9.37559509, 1.86339056, -0.23551451, 10.23400211, 3.93031883, -0.52629089, 7.85724449, -2.91549587, 4.46612740, 5.66530371, -2.70820427, 4.81359577, 10.31247330, 1.92230141, 2.53931546, 0.74986327, 1.70303428, 0.48063779, 5.31099129, -0.78976244, 3.75864220, 4.23051405, 2.34042454, -7.98193836, 9.83987141, -1.46722627, 3.54497814, 10.36455154, -4.51249075, 0.77715248, 7.78694630, -4.59989023, -2.49585629, 9.90296268, 1.38535416, 1.17441154, 10.10452843, -0.98628229, 0.60194463, 9.12639141, -3.90754628, 2.88526392, 7.24123430, -0.15283313, -0.75728363, -1.15116858, -2.53791571, 0.77229571, 6.44114161, 0.02646767, 4.95463037, 7.21066380, 1.79384065, 0.73250306, 8.04447937, 0.32576546, -0.79447043, 10.12717724, 2.33392906, 1.30716443, 12.36073112, -0.36694977, -1.20438910, 7.03105593, 0.59557682, 0.69267452, 10.18113136, 2.49944925, -0.42229167, 8.83143330, -1.18805945, -2.87509322, 4.53596449, 4.09732771, -3.39088297, -1.02536607, 0.82119560, -3.47302604, 9.29991817, 0.21001509, 4.97036457, 9.50018406, 1.04420102, 1.96560478, 10.74769592, -6.22709799, 3.11690164, 5.06759691, -1.23724771, -3.05831861, 8.12925529, -1.93435478, -1.10151744, 9.32263088, -0.04249470, -5.98547363, 10.49398136, 0.26400441, -0.78915191, 13.28219604, 2.99276900, 0.74853164, 2.49364305, -3.43529654, 4.05278301, 2.13498688, -2.35444307, -0.79900265, 4.66968822, -0.31095147, 3.60674143, 12.37222099, -0.07855003, -3.30292702, 12.15215874, 0.60886210, 2.87075138, 7.75271845, 0.38044083, 3.34402204, 6.40583277, -0.87888050, 0.67438459, 6.91080809, 1.98332930, -0.08303714, 8.08630371, -0.16772588, -2.74058914, 7.17253590, -2.69122696, 1.48173678, 8.99470139, -1.43302310, -0.88651133, 2.66944790, -0.29186964, 2.00838661, 5.09587479, -0.76676071, -2.88322186, 8.31110573, -0.14550979, -1.37726915, 10.28355122, -1.60575438, -0.04118848, 9.97510815, 0.14440438, -3.24632120, 9.00034523, 4.14319563, -1.31023729, 7.16950464, -0.70428526, 2.01559544, 7.26155043, 2.40816474, 2.09847403, 7.31264496, 
-0.75401551, 2.13392544, 7.03648758, 1.04036045, -1.15636516, 1.09634531, -0.06340861, -0.58107805, -0.65623116, 1.18972754, -0.80717683, 1.40118241, -0.61932516, -3.60596156, 1.59904599, -2.23774099, -1.13721037, 3.89620137, -0.09115922, -7.51356888, 2.36975193, -1.42520905, -2.34173775, 3.33830214, -2.74016523, -3.04115510, 6.00119495, -1.36084354, -2.45065260, 4.56992292, -3.02825928,-3.74182844,5.11069250,-0.91531068,-2.31385994,1.83399653,3.39370203,-3.60886002}); + auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); + auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260, 0.06878620, 2.27749538, 7.29276514, -0.14074677, 0.65480286, 5.70313978, -0.06546132, 0.35443667, 3.70382833, -0.84020567, 0.63826996, 8.60301399, -0.38236514, 1.55177069, 7.37542057, -0.99374938, -0.29971302, 8.84352493, -0.67121059, 0.43132120, 4.78175592, -1.25070143, -1.91523600, 6.03855371, -0.00292124, -1.11214364, 7.90158176, -0.57949901, -0.96735370, 7.81192017, -0.53255427, -0.48009714, 3.16953635, 0.08353355, -1.54299748, 3.74821687, 1.69396687, 0.72724354, 5.42915201, -1.13686812, -0.71793109, 5.78376389, -0.72239977, -0.60055625, 2.53636408, 0.56777251, -2.07892323, 6.08064651, 0.68620735, 2.54017019, 5.65828180, -0.68255502, 1.47283304, 6.10842514, -0.39655915, 0.28380761, 1.96707797, -1.98206317, 0.94027776, 4.71811438, 0.32104525, -0.92409706, 8.34588146, -1.05581069, -0.55217457, 9.58440876, -0.96549922, 0.45820439, 5.65453672, -2.50953507, -0.71441835, 8.03059578, -0.21281289, 0.92125505, 9.26900673, -0.35963219, -0.70039093, 8.59924412, -1.22358346, 0.81318003, 3.85920119, -0.01305223, -1.09234154, 6.33158875, 1.28094780, -1.48926139, 4.94969177, -0.77126902, -1.97033751, 5.64381838, -0.16285487, -1.31277227, 2.39893222, -1.32902908, -1.39609122, 6.47572327, -0.45267010, 1.55727172, 6.70965624, -1.68735468, -0.05672536, 7.25092363, -0.64613032, 0.67050058, 3.60789680, -2.05948973, 2.22687531, 8.15202713, -0.70148355, 1.28314006, 8.14842319, -1.88807654, 
-1.04808438, 8.45500565, -0.76425624, 0.94542569, 4.56179953, -0.28786001, -2.04502511, 8.46278095, -0.31019822, 0.07339200, 9.34214592, -0.61948007, 0.52481830, 8.32515621, -1.52418160, 0.49678251, 5.11082315, -1.09908783, -0.52969611, 5.27806664, 0.88632923, 0.66754371, 4.75839233, 0.48928693, -0.68036932, 6.56925392, -0.02949905, -2.99189186, 4.46320581, -0.64534980, -0.29516968, 8.60809517, -1.13120568, 3.41720533, 5.84243155, -1.24109328, 0.89566326, 5.99578333, -0.42496428, 2.07076764, 3.17812920, -0.81566459, -0.14363396, 6.55184317, 0.39633346, -0.43852386, 8.70214558, -2.24613595, 0.30708700, 8.73882294, -0.53545928, 1.54409575, 4.49452257, -0.16509305, 0.19028664, 8.24897003, 0.44750381, 2.15448594, 8.97640514, -0.77728152, 0.57272542, 9.03467560, 0.47173575, -1.10807717, 3.30056310, -0.43268481, -0.41470885, 3.53798294, -0.08546703, -2.16840744, 6.18733406, -0.17871059, -2.59837723, 5.94218683, -1.02990067, -0.49760687, 3.76938033, 0.86383581, -1.91504073}); + + sd::ops::avgpool2d op; + + NDArray::prepareSpecialUse({&z}, {&input}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + + Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1}; + + auto hash = op.getOpHash(); + auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&input}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(JavaInteropTests, Test_MaxPool2D_float_1) { + auto input = NDArrayFactory::create('c', {1, 1, 4, 5}); + auto z = NDArrayFactory::create('c', {1, 1, 4, 5}); + + 
input.linspace(1.0); + + NDArray::prepareSpecialUse({&z}, {&input}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + + Nd4jLong iArgs[] = {2,2, 1,1, 1,1, 2,2,1, 0,0}; + + sd::ops::maxpool2d op; + + auto hash = op.getOpHash(); + auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&input}); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(JavaInteropTests, Test_Unstack_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + auto z0 = NDArrayFactory::create('c',{5}); + auto z1 = NDArrayFactory::create('c',{5}); + auto z2 = NDArrayFactory::create('c',{5}); + auto z3 = NDArrayFactory::create('c',{5}); + auto z4 = NDArrayFactory::create('c',{5}); + + NDArray::prepareSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(x.buffer()), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {z0.buffer(), z1.buffer(), z2.buffer(), z3.buffer(), z4.buffer(), z0.specialBuffer(), z1.specialBuffer(), z2.specialBuffer(), z3.specialBuffer(), z4.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z0.shapeInfo(), (Nd4jPointer)z1.shapeInfo(), (Nd4jPointer)z2.shapeInfo(), + (Nd4jPointer)z3.shapeInfo(), (Nd4jPointer)z4.shapeInfo(), (Nd4jPointer)z0.specialShapeInfo(), + (Nd4jPointer)z1.specialShapeInfo(), (Nd4jPointer)z2.specialShapeInfo(), + (Nd4jPointer)z3.specialShapeInfo(), (Nd4jPointer)z4.specialShapeInfo()}; + + Nd4jLong iArgs[] = {0}; + + sd::ops::unstack 
op; + + auto hash = op.getOpHash(); + auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 5, nullptr, 0, iArgs, 1, nullptr, 0, false); + + NDArray::registerSpecialUse({&z0, &z1, &z2, &z3, &z4}, {&x}); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(JavaInteropTests, Test_AveragePooling_FF_TF_float) { + + auto input = NDArrayFactory::create('c', {4, 10, 10, 3}, {9.37125111f, 2.20166993f,2.91434479f,5.43639755f,-2.10573769f, 4.08528662f,5.86908436f,-4.46203756f,2.21057916f,5.35849190f,0.01394637f, 4.40566349f, 7.07982206f, -0.09633455f, 2.42429352f, 3.97301817f, -1.89553940f, 1.99690318f, 6.33141708f, 0.55401880f, 1.70707977f, 5.55204201f, -0.03513752f, 1.60011971f, 2.62700319f, -2.74582434f, 3.06697464f, 1.06277943f, -1.16075921f, -0.78095782f, 9.72352791f, -1.22686064f, 1.99644792f, 7.35571337f, 1.40607321f, 0.11390255f, 9.53334427f, 2.28303599f, -1.66728830f, 6.16678810f, -0.04532295f, -1.97708666f, 9.74906158f, 1.46223176f, -1.46734393f, 4.30761862f, -1.23790228f, 1.24823606f, 6.13938427f, -3.83689475f, -1.19625473f, 7.91535568f, 6.05868721f, -3.22946382f, 8.81633949f, -0.19967777f, 0.66053957f, 2.30919123f, 0.74543846f, -0.39347672f, 11.11058044f, 0.53720862f, 1.52645731f, 5.70012379f, -1.15213466f, 1.16451406f, 7.00526333f, 1.57362783f, -2.44384766f, 5.54213285f, -1.98828590f, -0.70483637f, 7.88281822f, -3.59875536f, 0.80745387f, 13.41578484f, -1.55507684f, -0.65855008f, 9.32583523f, -0.14544789f, 0.73436141f, 3.61176538f, -1.71268058f, -2.58490300f, 9.09280205f, -3.27405524f, -2.04569697f, 4.44761324f, -0.62955856f, -2.61917663f, 8.04890442f, 0.54579324f, 0.85929775f, 9.82259560f, -1.93825579f, 0.77703512f, 4.67090321f, -4.79267597f, -2.38906908f, 9.31265545f, 0.96026313f, -1.14109385f, 11.54231834f, -0.01417295f, -0.39500344f, 8.49191666f, 0.55300158f, 2.79490185f, 6.92466164f, 1.72254205f, 2.82222271f, 8.83112717f, 2.95033407f, 2.18054962f, 6.73509789f, -2.22272944f, 0.51127720f, -1.04563558f, 2.15747333f, 
-2.30959272f, 9.55441570f, 1.50396204f, 1.77370787f, 7.38146257f, -1.79076433f, 3.20961165f, 7.18864202f, 2.91217351f, 0.43018937f, 7.11078024f, -1.17386127f, -0.16817921f, 6.12327290f, -2.82205725f, 3.30696845f, 13.51291752f, -1.30856836f, -2.38332748f, 11.09487438f, -1.47190213f, -0.53050828f, 4.38285351f, -5.07309771f, 1.50714362f, 5.72274446f, -2.85825086f, -0.89673209f, 3.73791552f, -0.67708802f, -4.13149452f, -0.00671843f, -0.26566532f, 0.32961160f, 7.14501762f, -1.41608179f, -4.96590328f, 12.26205540f, -0.65158135f, -0.88641000f, 6.95777559f, -0.79058206f, -0.10260171f, 7.87169170f, 1.35921454f, 1.11759663f, 5.46187401f, -2.57214499f, 2.48484039f, 4.04043484f, -2.07137156f, -1.42709637f, 9.25487137f, -0.12605135f, -2.66949964f, 2.89412403f, 0.74451172f, -2.96250391f, 3.99258423f, 0.27084303f, 0.32213116f, 5.42332172f, -0.44414216f, 1.70881832f, 6.69346905f, 0.53058422f, -4.73146200f, 4.22051668f, 2.24834967f, 0.66996074f, 4.30173683f, 0.11849818f, -4.07520294f, 8.27318478f, -2.54398274f, -2.86705542f, 10.11775303f, -0.99382895f, 0.65881538f, 7.93556786f, -1.27934420f, -1.69343162f, 9.68042564f, -1.02609646f, -1.18189347f, 5.75370646f, -1.67888868f, -4.48871994f, 4.79537392f, -0.79212248f, -0.19855022f, 6.15060997f, -0.01081491f, 3.64454579f, 10.82562447f, 1.58859253f, -2.65847278f, 8.60093212f, -1.59196103f, 0.07635692f, 11.76175690f, -1.17453325f, 0.10122013f, 6.86458445f, -2.18891335f, -2.74004745f, 8.07066154f, 0.71818852f, -2.03035975f, 6.31053686f, 0.51509416f, 1.39789927f, 9.43515587f, 2.04256630f, 0.13985133f, 4.65010691f, 2.40911126f, -0.36255789f, -3.06867862f, -0.45225358f, -1.56778407f, 6.05917358f, -1.09891272f, 1.77184200f, 6.46248102f, 0.96042323f, -0.24346280f, 4.63436460f, -4.69907761f, 1.25187206f, 11.46173859f, -2.21917558f, 1.28007793f, 6.92173195f, 2.11268163f, -3.47389889f, 5.08722782f, -3.03950930f, -4.17154264f, 11.30568314f, 0.80361372f, 2.53214502f, 7.18707085f, -4.49114513f, 2.85449266f, 10.14906883f, -0.31974933f, -0.84472644f, 
-0.52459574f, 0.12921631f, -1.81390119f, 2.76170087f, 1.03982210f, 2.91744232f, -0.29048753f, 5.87453508f, -1.53684759f, 1.85800636f, -0.91404629f, 1.28954852f, 5.11354685f, -2.47475505f, -1.33179152f, 2.58552408f, 1.37316465f, -3.32339454f, 1.54122913f, 3.24953628f, -0.29758382f, 2.82391763f, -1.51142192f, -1.22699404f, 6.75745535f, 0.65452754f, -3.29385471f, 2.06008053f, 2.53172946f, -4.23532820f, -1.53909743f, -0.07010663f, -1.42173731f, 7.29031610f, -0.18448229f, 4.59496164f, 6.73027277f, 0.73441899f, 0.14426160f, 4.14915276f, -2.97010231f, 6.05851364f, 4.95218086f, -2.39145470f, 2.40494704f, 2.10288811f, 0.53503096f, 1.44511235f, 6.66344261f, -3.05803776f, 7.21418667f, 3.30303526f, -0.24163735f, 3.47409391f, 3.64520788f, 2.15189481f, -3.11243272f, 3.62310791f, 0.37379482f, 0.40865007f, -0.83132005f, -4.78246069f, 2.07030797f, 6.51765442f, 3.16178989f, 5.06180477f, 3.78434467f, -0.96689719f, 0.35965276f, 5.89967585f, 1.40294051f, 1.11952639f, 10.59778214f, 0.26739889f, -1.61297631f, 6.24801159f, -0.93914318f, -0.57812452f, 9.92604542f, -0.73025000f, -3.38530874f, 2.45646000f, -2.47949195f, 0.51638460f, 10.65636063f, 1.97816694f, -3.00407791f, 2.66914415f, -0.81951088f, -0.23316640f, 2.40737987f, -2.70007610f, 1.51531935f, 4.08860207f, -0.27552786f, -1.31721711f, 7.11568260f, -3.33498216f, -4.02545023f, 7.22675610f, -0.81690705f, -2.52689576f, 1.04016697f, -0.79291463f, -0.34875512f, 10.00498390f, -4.24167728f, 1.46162593f, 11.82569408f, -1.70359993f, -0.30161047f, 16.44085884f, -0.82253462f, -0.09435523f, 6.13080597f, -0.20259480f, 0.68308711f, 6.15663004f, -6.61776876f, 0.33295766f, 2.55449438f, -0.17819691f, -1.14892209f, 5.56776142f, 1.99279118f, 1.33035934f, 4.45823956f, 3.34916544f, -2.59905386f, 6.16164446f, -2.03881931f, -2.45273542f, 12.46793365f, -2.22743297f, 2.83738565f, 8.48628139f, -1.39347959f, -1.30867767f, 11.08041477f, -4.00363779f, 2.09183025f, 11.30395889f, -2.20504737f, 1.37426853f, 8.98735619f, 1.04676604f, -0.72757077f, 8.28050232f, 
-6.70741081f, -0.65798020f, 5.68592072f, -0.60760021f, 0.35854483f, 6.26852131f, 1.94100165f, 1.32112014f, 0.80987954f, -1.74617672f, -0.25434083f, 7.16045523f, 1.58884013f, -2.64847064f, 13.14820385f, 1.21393633f, -2.47258949f, 9.41650105f, -0.79384226f, 2.48954105f, 10.95629311f, 0.47723705f, 4.02126694f, 8.02593136f, -2.20726371f, -1.18794477f, 1.50836647f, 0.93118095f, -1.73513174f, 8.85493565f, -2.99670315f, -0.79055870f, 2.39473820f, 2.05046916f, -2.38055134f, 11.82299423f, 0.15609655f, 0.68744308f, 5.66401434f, -0.69281673f, 2.09855556f, 7.74626589f, -0.34283102f, 1.00542057f, 9.95838642f, 0.80161905f, 2.33455157f, 9.80057335f, -0.93561798f, 2.56991577f, 8.29711342f, 0.94213426f, 0.44209945f, 11.70259857f, 0.92710167f, 2.60957146f, 0.24971688f, -0.86529571f, 3.78628922f, 6.80884457f, -0.68178189f, 2.21103406f, 3.18895817f, 0.60283208f, -2.92716241f, 6.72060776f, -1.06625068f, 2.56543374f, 9.97404480f, 3.58080721f, -0.94936347f, 10.16736984f, -1.38464379f, 1.18191063f, 6.66179037f, -3.56115270f, 0.32329530f, 10.90870762f, 2.20638227f, 0.19653285f, 7.34650040f, -3.63859272f, -1.03027737f, 5.98829985f, -3.66606474f, -3.89746714f, 8.63469028f, 1.22569811f, 1.63240814f, 3.74385309f, 0.58243257f, -0.56981975f, 3.69260955f, 1.00979900f, -1.44030499f, 8.57058144f, -1.10648811f, 1.20474911f, 5.43133020f, -2.14822555f, -0.07928789f, 11.25825310f, 0.19645604f, -5.49546146f, 10.41917038f, -0.68178523f, -2.99639869f, 6.50054455f, 0.46488351f, -5.42328453f, 9.09500027f, -2.82107449f, 0.05601966f, 15.34610748f, -0.06820253f, 3.86699796f, 10.73316956f, -3.04795432f, -0.14702171f, 5.64813185f, 1.44028485f, -2.47596145f, 0.07280898f, -3.03187990f, -1.35183525f, 9.35835648f, 2.72966957f, 1.88199532f, 10.36187744f, -0.22834805f, -3.26738238f, 6.92025137f, -2.34061313f, 4.77379704f, 5.28559113f, -2.96323752f, -1.76186585f, 5.94436455f, 0.38647744f, -5.73869514f, 6.76849556f, 1.40892124f, -1.19068217f, 5.37919092f, -6.65328646f, 3.62782669f, 12.34744644f, 2.44762444f, 
-4.19242620f, 6.14906216f, 0.08121119f, 0.61355996f, 2.69666457f, -1.88962626f, -0.55314136f, 1.84937525f, 1.56048691f, 1.17460012f, 3.75674725f, 1.06198275f, -5.74625874f, 5.41645575f, -1.28946674f, -1.51689398f, 4.32400894f, -0.05222082f, -4.83948946f, 1.80747867f, 1.63144708f, -2.73887825f, 1.63975775f, -2.02163982f, -0.16210437f, 2.93518686f, 1.14427686f, -2.83246303f, 4.79283667f, 2.69697428f, -3.12678456f, -1.19225168f, -2.37022972f, -3.09429741f, 1.94225383f, -1.13747168f, -2.55048585f, 5.40242243f, 1.12777328f, 3.43713188f, 3.62658787f, -2.16878843f, 0.30164462f, 2.97407579f, -0.07275413f, -1.31149673f, 4.70066261f, -2.01323795f, 4.85255766f, 4.59128904f, 1.68084168f, 1.60336494f, 6.58138466f, -1.04759812f, 2.69906545f, 3.55769277f, -0.74327278f, 2.65819693f, 5.39528131f, 2.11248922f, -1.06446671f, 5.24546766f, -2.43146014f, 4.58907509f, 0.06521678f, -2.24503994f, 2.45722699f, 6.94863081f, 0.35258654f, 2.83396196f, 9.92525196f, -1.12225175f, -0.34365177f, 7.19116688f, -4.39813757f, 0.46517885f, 13.22028065f, -2.57483673f, -6.37226963f, 7.58046293f, -2.74600363f, 0.42231262f, 8.04881668f, 0.17289802f, -0.53447008f, 16.55157471f, -5.63614368f, 0.39288223f, 3.37079263f, 1.26484549f, -0.12820500f, 8.46440125f, -4.39304399f, 2.97676420f, 0.65650189f, 0.83158541f, -1.11556435f, 6.32885838f, -0.36087769f, 2.80724382f, 9.90292645f, 1.15936041f, 0.20947981f, 6.91249275f, -2.67404819f, 2.93782163f, 6.65656614f, -2.30828357f, 2.98214006f, 6.80611229f, -4.93821478f, -7.66555262f, 7.59763002f, -0.54159302f, 3.87403512f, 12.42607784f, 2.59284401f, -0.23375344f, 8.95293331f, -0.71807784f, 0.61873478f, 8.66713524f, 1.24289191f, -2.37835455f, 2.08071637f, -0.88315344f, -3.41891551f, 6.85245323f, 1.73007369f, 1.02169311f, 7.69170332f, -2.85411978f, 2.69790673f, 8.12906551f, -1.19351399f, -2.26442742f, 12.26104450f, -0.75579089f, -1.73274946f, 10.68729019f, 2.20655656f, -0.90522075f, 12.42165184f, -1.67929137f, 2.44851565f, 9.31565762f, -0.06645700f, 1.52762020f, 6.18427515f, 
-1.68882596f, 3.70261097f, 3.02252960f, -3.44125366f, -1.31575799f, 2.84617424f, -0.96849400f, -4.52356243f, 9.95027161f, 0.19966406f, -0.78874779f, 8.18595028f, -4.08300209f, 1.75126517f, 0.96418417f, -4.04913044f, -0.95200396f, 12.03637886f, -0.03041124f, 0.41642749f, 8.88267422f, -3.24985337f, -2.24919462f, 7.32566118f, 0.16964148f, -2.74123430f, 7.05264473f, -3.30191112f, 0.17163286f, 4.81851053f, -1.64463484f, -0.85933101f, 7.29276276f, 2.34066939f, -2.14860010f, 3.46148157f, -0.01782012f, 1.51504040f, 4.79304934f, 1.85281146f, -1.70663762f, 6.93470192f, -4.15440845f, -1.25983095f, 10.52491760f, 0.42930329f, -1.85146868f, 11.70042324f, -0.41704914f, 3.83796859f, 9.21148491f, -2.79719448f, 0.79470479f, 6.26926661f, -5.85230207f, 3.95105338f, 7.84790897f, -1.38680744f, -1.78099084f, 11.95235348f, -2.99841452f, -1.34507811f, 6.15714645f, -1.07552516f, -2.81228638f, 1.66234732f, -4.55166149f, -1.92601109f, 8.64634514f, -0.48158705f, 3.31595659f, 7.67371941f, 2.56964207f, 0.12107098f, 4.56467867f, -0.93541539f, 1.39432955f, 11.99714088f, 1.05353570f, -2.13099813f, 3.67617917f, 3.45895386f, 1.37365830f, 8.74344158f, -4.17585802f, 1.43908918f, 6.28764772f, 3.97346330f, -0.69144285f, 9.07983303f, -0.41635889f, -0.14965028f, 8.85469818f, 1.11306190f, 2.59440994f, 5.38982344f, -1.07948279f, 1.37252975f, 10.26984596f, -0.09318046f, 2.73104119f, 12.45902252f, -1.55446684f, -2.76124811f, 12.19395065f, -0.51846564f, 1.02764034f, 11.42673588f, -0.95940983f, -0.04781032f, 8.78379822f, -4.88957930f, 0.32534006f, 11.97696400f, -3.35108662f, 1.95104563f, 4.46915388f, -2.32061648f, 3.45230985f, 8.29983711f, 2.81034684f, -2.35529327f, 6.07801294f, -0.98105043f, -0.05359888f, 2.52291036f, -0.01986909f, -2.35321999f, 10.51954269f, 2.11145401f, 3.53506470f, 7.29093266f, 0.03721160f, -1.13496494f, 7.43886709f, -5.84201956f, 2.50796294f, 12.14647675f, 2.77490377f, -2.18896222f, 6.05641937f, 5.32617044f, 1.04221284f, 10.79106712f, -2.95749092f, -2.75414610f, 11.30037117f, -3.40654182f, 
-2.24673963f, 7.49126101f, 0.70811015f, -6.18003702f, 13.83951187f, -1.01204085f, 1.36298490f, -1.04451632f, 2.42435336f, -0.02346706f, -0.85528886f, 1.04731262f, 0.22192979f, 4.15708160f, 0.34933877f, 0.04814529f, 2.24107265f, 0.49676740f, -1.47752666f, 0.45040059f, -0.70471478f, -1.19759345f, 0.21711677f, 0.88461423f, -2.76830935f, 5.52066898f, 1.97664857f, -1.75381601f, 3.45877838f, 1.52617192f, -1.61350942f, 0.85337949f, 1.97610760f, -3.40310287f, 3.40319014f, -3.38691044f, -0.71319139f, 1.65463758f, -0.60680127f, -1.80700517f, 8.02592373f, 2.59627104f, 2.65895891f, 5.93043184f, -4.48425817f, 3.92670918f, 4.19496679f, -2.28286791f, 6.41634607f, 5.72330523f, 1.16269672f, -0.28753027f, 2.46342492f, 0.36693189f, 0.26712441f, 6.37652683f, -2.50139046f, 2.43923736f, 5.56310415f, 0.98065847f, 1.04267502f, 4.16403675f, -0.04966142f, 4.40897894f, 3.72905660f, -3.46129870f, 3.59962773f, 1.34830284f, -1.76661730f, 0.47943926f, 5.29946661f, -1.12711561f, 1.26970029f, 15.17655945f, -1.50971997f, 5.81345224f, 8.48562050f, -4.36049604f, 2.48144460f, 8.23780441f, -3.46030426f, -0.84656560f, 5.94946814f, 1.12747943f, -2.65683913f, 8.69085693f, 1.31309867f, -2.79958344f, 8.76840591f, -1.56444156f, 1.62710834f, 2.41177034f, -0.72804940f, 5.70619011f, 4.67169666f, -0.86167198f, -1.83803177f, 2.96346045f, 2.82692933f, -2.81557131f, 7.11113358f, -1.90071094f, 2.54244423f, 11.19284058f, -0.06298946f, -1.71517313f, 12.98388577f, 0.84510714f, 3.00816894f, 2.57200313f, 0.03899818f, -1.49330592f, 9.60099125f, -3.59513044f, -1.30045319f, 7.09241819f, -0.65233821f, -2.33627677f, 8.81366920f, 0.84154201f, 1.03312039f, 9.85289097f, 0.19351870f, 1.78496623f, 7.34631205f, -2.16530800f, -0.65016162f, 2.46842360f, 0.24016285f, -1.24308395f, 4.78175163f, -0.97682536f, 2.20942235f, 6.68382788f, 3.76786447f, -1.44454038f, 6.26453733f, -3.23575711f, -2.30137897f, 9.53092670f, -5.55222607f, 3.25999236f, 9.37559509f, 1.86339056f, -0.23551451f, 10.23400211f, 3.93031883f, -0.52629089f, 7.85724449f, 
-2.91549587f, 4.46612740f, 5.66530371f, -2.70820427f, 4.81359577f, 10.31247330f, 1.92230141f, 2.53931546f, 0.74986327f, 1.70303428f, 0.48063779f, 5.31099129f, -0.78976244f, 3.75864220f, 4.23051405f, 2.34042454f, -7.98193836f, 9.83987141f, -1.46722627f, 3.54497814f, 10.36455154f, -4.51249075f, 0.77715248f, 7.78694630f, -4.59989023f, -2.49585629f, 9.90296268f, 1.38535416f, 1.17441154f, 10.10452843f, -0.98628229f, 0.60194463f, 9.12639141f, -3.90754628f, 2.88526392f, 7.24123430f, -0.15283313f, -0.75728363f, -1.15116858f, -2.53791571f, 0.77229571f, 6.44114161f, 0.02646767f, 4.95463037f, 7.21066380f, 1.79384065f, 0.73250306f, 8.04447937f, 0.32576546f, -0.79447043f, 10.12717724f, 2.33392906f, 1.30716443f, 12.36073112f, -0.36694977f, -1.20438910f, 7.03105593f, 0.59557682f, 0.69267452f, 10.18113136f, 2.49944925f, -0.42229167f, 8.83143330f, -1.18805945f, -2.87509322f, 4.53596449f, 4.09732771f, -3.39088297f, -1.02536607f, 0.82119560f, -3.47302604f, 9.29991817f, 0.21001509f, 4.97036457f, 9.50018406f, 1.04420102f, 1.96560478f, 10.74769592f, -6.22709799f, 3.11690164f, 5.06759691f, -1.23724771f, -3.05831861f, 8.12925529f, -1.93435478f, -1.10151744f, 9.32263088f, -0.04249470f, -5.98547363f, 10.49398136f, 0.26400441f, -0.78915191f, 13.28219604f, 2.99276900f, 0.74853164f, 2.49364305f, -3.43529654f, 4.05278301f, 2.13498688f, -2.35444307f, -0.79900265f, 4.66968822f, -0.31095147f, 3.60674143f, 12.37222099f, -0.07855003f, -3.30292702f, 12.15215874f, 0.60886210f, 2.87075138f, 7.75271845f, 0.38044083f, 3.34402204f, 6.40583277f, -0.87888050f, 0.67438459f, 6.91080809f, 1.98332930f, -0.08303714f, 8.08630371f, -0.16772588f, -2.74058914f, 7.17253590f, -2.69122696f, 1.48173678f, 8.99470139f, -1.43302310f, -0.88651133f, 2.66944790f, -0.29186964f, 2.00838661f, 5.09587479f, -0.76676071f, -2.88322186f, 8.31110573f, -0.14550979f, -1.37726915f, 10.28355122f, -1.60575438f, -0.04118848f, 9.97510815f, 0.14440438f, -3.24632120f, 9.00034523f, 4.14319563f, -1.31023729f, 7.16950464f, -0.70428526f, 
2.01559544f, 7.26155043f, 2.40816474f, 2.09847403f, 7.31264496f, -0.75401551f, 2.13392544f, 7.03648758f, 1.04036045f, -1.15636516f, 1.09634531f, -0.06340861f, -0.58107805f, -0.65623116f, 1.18972754f, -0.80717683f, 1.40118241f, -0.61932516f, -3.60596156f, 1.59904599f, -2.23774099f, -1.13721037f, 3.89620137f, -0.09115922f, -7.51356888f, 2.36975193f, -1.42520905f, -2.34173775f, 3.33830214f, -2.74016523f, -3.04115510f, 6.00119495f, -1.36084354f, -2.45065260f, 4.56992292f, -3.02825928f, -3.74182844f, 5.11069250f, -0.91531068f, -2.31385994f, 1.83399653f, 3.39370203f, -3.60886002f}); + auto z = NDArrayFactory::create('c', {4, 4, 4, 3}); + auto exp = NDArrayFactory::create('c', {4, 4, 4, 3}, {7.97172260f, 0.06878620f, 2.27749538f, 7.29276514f, -0.14074677f, 0.65480286f, 5.70313978f, -0.06546132f, 0.35443667f, 3.70382833f, -0.84020567f, 0.63826996f, 8.60301399f, -0.38236514f, 1.55177069f, 7.37542057f, -0.99374938f, -0.29971302f, 8.84352493f, -0.67121059f, 0.43132120f, 4.78175592f, -1.25070143f, -1.91523600f, 6.03855371f, -0.00292124f, -1.11214364f, 7.90158176f, -0.57949901f, -0.96735370f, 7.81192017f, -0.53255427f, -0.48009714f, 3.16953635f, 0.08353355f, -1.54299748f, 3.74821687f, 1.69396687f, 0.72724354f, 5.42915201f, -1.13686812f, -0.71793109f, 5.78376389f, -0.72239977f, -0.60055625f, 2.53636408f, 0.56777251f, -2.07892323f, 6.08064651f, 0.68620735f, 2.54017019f, 5.65828180f, -0.68255502f, 1.47283304f, 6.10842514f, -0.39655915f, 0.28380761f, 1.96707797f, -1.98206317f, 0.94027776f, 4.71811438f, 0.32104525f, -0.92409706f, 8.34588146f, -1.05581069f, -0.55217457f, 9.58440876f, -0.96549922f, 0.45820439f, 5.65453672f, -2.50953507f, -0.71441835f, 8.03059578f, -0.21281289f, 0.92125505f, 9.26900673f, -0.35963219f, -0.70039093f, 8.59924412f, -1.22358346f, 0.81318003f, 3.85920119f, -0.01305223f, -1.09234154f, 6.33158875f, 1.28094780f, -1.48926139f, 4.94969177f, -0.77126902f, -1.97033751f, 5.64381838f, -0.16285487f, -1.31277227f, 2.39893222f, -1.32902908f, -1.39609122f, 6.47572327f, 
-0.45267010f, 1.55727172f, 6.70965624f, -1.68735468f, -0.05672536f, 7.25092363f, -0.64613032f, 0.67050058f, 3.60789680f, -2.05948973f, 2.22687531f, 8.15202713f, -0.70148355f, 1.28314006f, 8.14842319f, -1.88807654f, -1.04808438f, 8.45500565f, -0.76425624f, 0.94542569f, 4.56179953f, -0.28786001f, -2.04502511f, 8.46278095f, -0.31019822f, 0.07339200f, 9.34214592f, -0.61948007f, 0.52481830f, 8.32515621f, -1.52418160f, 0.49678251f, 5.11082315f, -1.09908783f, -0.52969611f, 5.27806664f, 0.88632923f, 0.66754371f, 4.75839233f, 0.48928693f, -0.68036932f, 6.56925392f, -0.02949905f, -2.99189186f, 4.46320581f, -0.64534980f, -0.29516968f, 8.60809517f, -1.13120568f, 3.41720533f, 5.84243155f, -1.24109328f, 0.89566326f, 5.99578333f, -0.42496428f, 2.07076764f, 3.17812920f, -0.81566459f, -0.14363396f, 6.55184317f, 0.39633346f, -0.43852386f, 8.70214558f, -2.24613595f, 0.30708700f, 8.73882294f, -0.53545928f, 1.54409575f, 4.49452257f, -0.16509305f, 0.19028664f, 8.24897003f, 0.44750381f, 2.15448594f, 8.97640514f, -0.77728152f, 0.57272542f, 9.03467560f, 0.47173575f, -1.10807717f, 3.30056310f, -0.43268481f, -0.41470885f, 3.53798294f, -0.08546703f, -2.16840744f, 6.18733406f, -0.17871059f, -2.59837723f, 5.94218683f, -1.02990067f, -0.49760687f, 3.76938033f, 0.86383581f, -1.91504073f}); + + sd::ops::avgpool2d op; + + NDArray::prepareSpecialUse({&z}, {&input}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(input.buffer()), input.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)input.shapeInfo(), (Nd4jPointer)input.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {reinterpret_cast(z.buffer()), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + Nd4jLong iArgs[] = {3,3, 3,3, 0,0, 1,1,1, 0,1}; + + auto hash = op.getOpHash(); + auto status = execCustomOp(nullptr, hash, ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, iArgs, 11, nullptr, 0, false); + + 
NDArray::registerSpecialUse({&z}, {&input}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(JavaInteropTests, Test_Mixed_Add_1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + auto arrayX = NDArrayFactory::create({1, 2, 3, 4}); + auto arrayY = NDArrayFactory::create({1, 2, 3, 4}); + auto arrayZ = NDArrayFactory::create({0, 0, 0, 0}); + auto arrayE = NDArrayFactory::create({2, 4, 6, 8}); + + NDArray::prepareSpecialUse({&arrayZ}, {&arrayX, &arrayY}); + + OpaqueDataBuffer xBuf(arrayX.dataBuffer()); + OpaqueDataBuffer yBuf(arrayY.dataBuffer()); + OpaqueDataBuffer zBuf(arrayZ.dataBuffer()); + + execPairwiseTransform(nullptr, pairwise::Add, + &xBuf, arrayX.shapeInfo(), arrayX.specialShapeInfo(), + &yBuf, arrayY.shapeInfo(), arrayY.specialShapeInfo(), + &zBuf, arrayZ.shapeInfo(), arrayZ.specialShapeInfo(), + nullptr); + + NDArray::registerSpecialUse({&arrayZ}, {&arrayX, &arrayY}); + + ASSERT_EQ(arrayE, arrayZ); +} + +TEST_F(JavaInteropTests, Test_Add_1) { + auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto y = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto e = NDArrayFactory::create('c', {5}, {2, 2, 2, 2, 2}); + + NDArray::prepareSpecialUse({&x}, {&x, &y}); + + sd::ops::add op; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), y.buffer(), x.specialBuffer(), y.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)y.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::registerSpecialUse({&x}, {&x, &y}); + + ASSERT_EQ(e, x); +} + 
+TEST_F(JavaInteropTests, zeta_test10) { + + auto x = NDArrayFactory::create('c', {3, 4}, {1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 1.01, 1.11, 1.12}); + auto q = NDArrayFactory::create('c', {3, 4}, {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.01, 0.11, 0.12}); + auto z = NDArrayFactory::create('c', {3, 4}); + + auto e = NDArrayFactory::create('c', {3, 4}, {23.014574, 12.184081, 8.275731, 6.1532226, 4.776538, 3.7945523, 3.0541048, 2.4765317, 2.0163891, 205.27448, 21.090889, 19.477398}); + + sd::ops::zeta op; + + NDArray::prepareSpecialUse({&z}, {&x, &q}); + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), q.buffer(), x.specialBuffer(), q.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)q.shapeInfo(), (Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)q.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + + execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + + NDArray::registerSpecialUse({&z}, {&x, &q}); + + ASSERT_EQ(e, z); +} + +TEST_F(JavaInteropTests, Test_IAMax_1) { + auto arrayX = NDArrayFactory::create({-0.24f, -0.26f, -0.07f, -0.01f}); + auto arrayZ = arrayX.indexReduceNumber(indexreduce::IndexAbsoluteMax, nullptr); + auto exp = NDArrayFactory::create(1); + + ASSERT_EQ(exp, arrayZ); +} + +TEST_F(JavaInteropTests, Test_Boolean_Broadcastables_1) { + auto arrayX = NDArrayFactory::create('c', {10, 10}); + auto arrayY = NDArrayFactory::create('c', {10, 10}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(arrayX.buffer()), reinterpret_cast(arrayY.buffer()), arrayX.specialBuffer(), arrayY.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)arrayX.shapeInfo(), (Nd4jPointer)arrayY.shapeInfo(), (Nd4jPointer)arrayX.specialShapeInfo(), 
(Nd4jPointer)arrayY.specialShapeInfo()}; + + NDArray::prepareSpecialUse({}, {&arrayX, &arrayY}); + sd::ops::greater_equal op; + auto shapeList = calculateOutputShapes2(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 2, nullptr, 0, nullptr, 0, nullptr, 0, nullptr, 0); + NDArray::registerSpecialUse({}, {&arrayX, &arrayY}); + delete shapeList; +} + +TEST_F(JavaInteropTests, Test_L2_Loss_3) { + auto x = NDArrayFactory::create(0.7787855863571167); + auto e = NDArrayFactory::create(0.303254); + auto z = NDArrayFactory::create(0.0); + + NDArray::prepareSpecialUse({&z}, {&x}); + + Nd4jPointer ptrsInBuffer[] = {reinterpret_cast(x.buffer()), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffer[] = {reinterpret_cast(z.buffer()), (Nd4jPointer)z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + + sd::ops::l2_loss op; + auto status = execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffer, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + ASSERT_EQ(Status::OK(), status); + + NDArray::registerSpecialUse({&z}, {&x}); + + ASSERT_EQ(e, z); +} + +TEST_F(JavaInteropTests, Test_Fastpath_3) { + auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {3, 2}); + + auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); + + NDArray::prepareSpecialUse({&z}, {&array0, &array1}); + + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + 
ASSERT_EQ(2, ctx.width()); + + sd::ops::add op; + execCustomOp2(nullptr, op.getOpHash(), &ctx); + + NDArray::registerSpecialUse({&z}, {&array0, &array1}); + + ASSERT_EQ(exp, z); +} + +TEST_F(JavaInteropTests, Test_Fastpath_4) { + + auto exp = NDArrayFactory::create('c', {3, 5}, {1,1,1,0,0, 1,1,1,1,0, 1,1,1,1,1}); + auto z = NDArrayFactory::create('c', {3, 5}); + Nd4jLong iArgs[] = {3, 5, 2}; + + + NDArray::prepareSpecialUse({&z}, {}); + + Context ctx(1); + + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setIArguments(iArgs, 3); + + sd::ops::tri op; + execCustomOp2(nullptr, op.getOpHash(), &ctx); + + NDArray::registerSpecialUse({&z}, {}); + + ASSERT_EQ(exp, z); +} + +TEST_F(JavaInteropTests, Test_Fastpath_5) { + auto a = NDArrayFactory::create('c', {3, 3}); + auto b = NDArrayFactory::create('c', {3, 3}); + auto c = NDArrayFactory::create('c', {3, 3}); + a.linspace(1.0); + b.linspace(1.0); + + NDArray::prepareSpecialUse({&c}, {&b, &c}); + + Context ctx(1); + + ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo()); + ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); + ctx.setOutputArray(0, c.buffer(), c.shapeInfo(), c.specialBuffer(), c.specialShapeInfo()); + + sd::ops::matmul op; + auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); + + NDArray::registerSpecialUse({&c}, {&b, &c}); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(JavaInteropTests, Test_Fastpath_6) { + auto a = NDArrayFactory::create('c', {2, 3}); + auto b = NDArrayFactory::create('c', {3, 4}); + auto gI = NDArrayFactory::create('c', {2, 4}); + + auto gA = NDArrayFactory::create('c', {2, 3}); + auto gB = NDArrayFactory::create('c', {3, 4}); + a.linspace(1.0); + b.linspace(1.0); + gI.linspace(1.0); + + NDArray::prepareSpecialUse({&gA, &gB}, {&a, &b, &gI}); + + Context ctx(1); + Nd4jLong iArgs[] = {0L, 0L, 0L}; + + ctx.setInputArray(0, a.buffer(), a.shapeInfo(), 
a.specialBuffer(), a.specialShapeInfo()); + ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); + ctx.setInputArray(2, gI.buffer(), gI.shapeInfo(), gI.specialBuffer(), gI.specialShapeInfo()); + + ctx.setOutputArray(0, gA.buffer(), gA.shapeInfo(), gA.specialBuffer(), gA.specialShapeInfo()); + ctx.setOutputArray(1, gB.buffer(), gB.shapeInfo(), gB.specialBuffer(), gB.specialShapeInfo()); + + ctx.setIArguments(iArgs, 3); + + sd::ops::matmul_bp op; + auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); + + NDArray::registerSpecialUse({&gA, &gB}, {&a, &b, &gI}); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(JavaInteropTests, Test_Fastpath_7) { + auto a = NDArrayFactory::create('c', {2}, {1.f, 2.f}); + auto b = NDArrayFactory::create(3.f); + auto z = NDArrayFactory::create('c', {3}); + auto e = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + + NDArray::prepareSpecialUse({&z}, {&a, &b}); + + Context ctx(1); + Nd4jLong iArgs[] = {0L, 0L, 0L}; + + ctx.setIArguments(iArgs, 1); + + sd::ops::concat op; + + ctx.setInputArray(0, a.buffer(), a.shapeInfo(), a.specialBuffer(), a.specialShapeInfo()); + ctx.setInputArray(1, b.buffer(), b.shapeInfo(), b.specialBuffer(), b.specialShapeInfo()); + + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + auto status = execCustomOp2(nullptr, op.getOpHash(), &ctx); + + NDArray::registerSpecialUse({&z}, {&a, &b}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(JavaInteropTests, test_bfloat16_rng) { + if (!Environment::getInstance().isCPU()) + return; + + auto z = NDArrayFactory::create('c', {10}); + RandomGenerator rng(119, 323841120L); + bfloat16 args[2] = {(bfloat16) 0.0f, (bfloat16) 1.0f}; + OpaqueDataBuffer zBuf(z.dataBuffer()); + execRandom(nullptr, sd::random::Ops::UniformDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), args); + + //z.printIndexedBuffer("z"); + ASSERT_TRUE(z.sumNumber().e(0) > 0); +} + 
+TEST_F(JavaInteropTests, test_ismax_view) { + auto original = NDArrayFactory::create('c', {2, 3, 40}); + auto v = original.subarray({NDIndex::all(), NDIndex::all(), NDIndex::interval(0, 40, 2)}); + v.assign(1.0); + + auto e = v.like(); + auto t = e(0, {2}); + t.assign(1.0); + + auto z = v.ulike(); + + + Nd4jLong iArgs[] = {2L, 0L}; + Context ctx(1); + ctx.setInputArray(0, v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + ctx.setIArguments(iArgs, 1); + + sd::ops::ismax op; + op.execute(&ctx); + + ASSERT_EQ(e, z); +} + +TEST_F(JavaInteropTests, test_size_dtype_1) { + auto x = NDArrayFactory::create('c', {3}, {1.f, 1.f, 1.f}); + auto z = NDArrayFactory::create(0.0f); + auto e = NDArrayFactory::create(3.0f); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + sd::ops::size op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} + +TEST_F(JavaInteropTests, test_expandable_array_op_1) { + auto x = NDArrayFactory::string( {2}, {"first string", "second"}); + auto d = NDArrayFactory::string(" ", sd::DataType::UTF8); + + auto z0 = NDArrayFactory::create('c', {6}); + auto z1 = NDArrayFactory::string( {3}, {"", "", ""}); + + auto exp0 = NDArrayFactory::create({0,0, 0,1, 1,0}); + auto exp1 = NDArrayFactory::string( {3}, {"first", "string", "second"}); + + InteropDataBuffer iz0(z0.dataBuffer()); + InteropDataBuffer iz1(z1.dataBuffer()); + + Context ctx(1); + ctx.setInputArray(0, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo()); + ctx.setInputArray(1, d.buffer(), d.shapeInfo(), d.specialBuffer(), d.specialShapeInfo()); + ctx.setOutputArray(0, &iz0, z0.shapeInfo(), z0.specialShapeInfo()); + ctx.setOutputArray(1, &iz1, z1.shapeInfo(), 
z1.specialShapeInfo()); + + sd::ops::compat_string_split op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(exp0, z0); + ASSERT_EQ(exp1, z1); +} + +TEST_F(JavaInteropTests, test_workspace_backed_arrays_1) { + if (!Environment::getInstance().isCPU()) + return; + + auto x = NDArrayFactory::create('c', {4, 3, 4, 4}); + auto y = NDArrayFactory::create('c', {4, 3, 3, 3}); + auto z = NDArrayFactory::create('c', {4, 3, 4, 4}); + + double buffer[2048]; + + InteropDataBuffer ix(0, DataType::DOUBLE, false); + InteropDataBuffer iy(0, DataType::DOUBLE, false); + InteropDataBuffer iz(0, DataType::DOUBLE, false); + + // we're imitating workspace-managed array here + ix.setPrimary(buffer + 64, x.lengthOf()); + iy.setPrimary(buffer + 64 + x.lengthOf(), y.lengthOf()); + iz.setPrimary(buffer + 64 + x.lengthOf() + y.lengthOf(), z.lengthOf()); + + Context ctx(1); + ctx.setInputArray(0, &ix, x.shapeInfo(), x.specialShapeInfo()); + ctx.setInputArray(1, &iy, y.shapeInfo(), y.specialShapeInfo()); + ctx.setOutputArray(0, &iz, z.shapeInfo(), z.specialShapeInfo()); + + ctx.setIArguments({2, 2, 1, 1, 0, 0, 1, 1, 0, 0, 0}); + + sd::ops::maxpool2d_bp op; + auto status = op.execute(&ctx); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(JavaInteropTests, test_linspace_shape_1) { + if (!Environment::getInstance().isCPU()) + return; + + sd::ops::lin_space op; + double tArgs[2] = {1.0, 10.0}; + Nd4jLong iArgs = 10L; + int dArg = (int) sd::DataType::FLOAT32; + auto result = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 2, &iArgs, 1, nullptr, 0, &dArg, 1); + + ASSERT_EQ(1, result->size()); + delete result; +} + +/* +TEST_F(JavaInteropTests, Test_Results_Conversion_1) { + auto pl = sd::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); + auto ptr = executeFlatGraph(nullptr, pl); + + // at this point we have FlatResults + auto flatResult = GetFlatResult(ptr->pointer()); + auto size = flatResult->variables()->size(); + + // we 
know exact number of outputs in this graph in given mode + ASSERT_EQ(184, size); + + + // now we're rolling through all variables and restore them one by one + for (int e = 0; e < size; e++) { + auto flatVar = flatResult->variables()->Get(e); + auto flatArray = flatVar->ndarray(); + + // checking var part first + // we just want to ensure we're not experiencing overruns here + auto name = flatVar->name()->str(); + + // checking array part now + auto shape = flatArray->shape(); + auto rank = shape->Get(0); + + ASSERT_TRUE(shape->size() > 0 && rank >= 0 && rank < MAX_RANK); + + // building regular NDArray out of this FlatArray + auto ndarray = sd::graph::FlatUtils::fromFlatArray(flatArray); + + // rank should match FlatArray + ASSERT_EQ(rank, ndarray->rankOf()); + + // array shouldn't have any NaN/Inf values + ASSERT_TRUE(ndarray->isFinite()); + + // array should be assignable + ndarray->assign(123.f); + + // and safely removable after + delete ndarray; + } + + + delete[] pl; + delete ptr; + + // and we should have 0 leaks reported after this line :) +} +*/ +// TEST_F(JavaInteropTests, Test_NLP_Aggregations_1) { +// std::array syn0 = {-0.022756476f, 0.0126427775f, 0.011029151f, -0.013542821f, -0.012327666f, -0.0032439455f, -0.008405109f, -0.016651405f, 0.0015980572f, -0.007442479f, 0.019937921f, -0.016222188f, -0.016541665f, 0.013372547f, 0.006625724f, 0.0058958204f, -0.01281835f, -6.2343775E-4f, 0.0019826533f, 0.010253737f, -0.010291531f, 0.0019767822f, 0.018071089f, -0.0117441565f, 0.023176769f, 0.0032820583f, 0.0061427564f, -0.01696018f, 0.0054971874f, 0.0043818625f, 0.019323621f, 0.0036080598f, 0.024376748f, -0.0024499625f, 0.019496754f, 0.010563821f, -2.0503551E-4f, -0.0146056535f, 0.009949291f, 0.017604528f, -0.0050302492f, -0.022060446f, 0.016468976f, -0.0034482107f, 0.010270384f, -0.0063356445f, -0.019934833f, -0.02325993f, 0.016109904f, -0.0031106502f, -0.0020592287f, 0.024031803f, 0.005184144f, -0.024887865f, 0.02100272f, 3.395051E-4f, 0.018432347f, 
5.673498E-4f, -0.020073576f, 0.010949242f}; +// std::array syn1; +// std::array exp; + +// for (int e = 0; e < syn1.size(); e++) +// syn1[e] = 0.0f; + +// for (int e = 0; e < exp.size(); e++) { +// auto f = static_cast(e); +// auto tmp = sd::math::nd4j_exp((f / 100000.0 * 2.0 - 1.0) * 6.0); +// exp[e] = static_cast(tmp / (tmp + 1.0)); +// } + +// auto maxTypes = 5; +// auto numAggregates = 1; +// auto opNum = 3; +// auto maxArgs = 6; +// auto maxShapes = 0; +// auto maxIntArrays = 2; +// auto maxIntArraySize = 40; +// auto maxIndexArguments = 10; +// auto maxRealArguments = 2; + +// std::array pointer; + +// auto batchLimit = 512; + +// int indexPos = maxTypes * batchLimit; +// int intArraysPos = indexPos + (maxIndexArguments * batchLimit); +// int realPos = (intArraysPos + (maxIntArrays * maxIntArraySize * batchLimit)); +// int argsPos = (realPos + ((maxRealArguments * batchLimit))) / 2; +// int shapesPos = argsPos + (maxArgs * batchLimit); + +// std::vector intArray0({0, 0, 0, 0, 0}); +// std::vector intArray1({1, 0, 0, 0, 0}); + +// std::vector indexingArgs0({1, 20, 5, 0, 100000, 3, 0, 0, 0}); +// std::vector indexingArgs1({0, 20, 5, 0, 100000, 3, 1, 0, 0}); + +// std::vector realArgs0({0.024964055335354007f, 3.0768702268737162E18f}); + +// int argSize = 6; +// int shapesSize = 0; +// int indexingSize = 9; +// int realArgsSize = 2; +// int intArraysSize = 2; + +// int e = 0; + +// auto idx = e * maxTypes; + +// // numbers of arguments +// pointer[idx] = 6; // arguments size +// pointer[idx+1] = 0; // shapes size +// pointer[idx+2] = 9; // indexing arguments size +// pointer[idx+3] = 2; // real args size +// pointer[idx+4] = 2; // intArray args size + +// // indexing args +// auto idxArgs = e == 0 ? 
indexingArgs0 : indexingArgs1; +// for (int f = 0; f < idxArgs.size(); f++) { +// idx = indexPos + e * maxIndexArguments; +// pointer[idx + f] = idxArgs[f]; +// } + +// // int array values +// int bsize = maxIntArrays * maxIntArraySize; +// for (int f = 0; f < intArraysSize; f++) { +// int step = (e * bsize) + (f * maxIntArraySize); +// auto intArr = f == 0 ? intArray0 : intArray1; +// for (int x = 0; x < intArr.size(); x++) { +// idx = intArraysPos + step + x; +// pointer[idx] = intArr[x]; +// } +// } + +// // real args +// auto ptr = reinterpret_cast(pointer.data()); +// for (int f = 0; f < realArgsSize; f++) { +// idx = realPos + (e * maxRealArguments); +// ptr[idx + f] = realArgs0[f]; +// } + +// // +// auto ptrptr = reinterpret_cast(pointer.data()); +// idx = argsPos + e * maxArgs; +// ptrptr[idx] = reinterpret_cast(syn0.data()); +// ptrptr[idx+1] = reinterpret_cast(syn1.data()); +// ptrptr[idx+2] = reinterpret_cast(exp.data()); + + +// execAggregateBatchFloat(nullptr, numAggregates, opNum, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data()); +// } diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LambdaTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LambdaTests.cu new file mode 100644 index 000000000..743e6cff2 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LambdaTests.cu @@ -0,0 +1,221 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; + +class LambdaTests : public testing::Test { +public: + + LambdaTests() { + printf("\n"); + fflush(stdout); + } +}; + +template +__global__ void runLambda(double *input, double *output, Nd4jLong length, Lambda lambda) { + auto tid = blockIdx.x * blockDim.x + threadIdx.x; + for (Nd4jLong e = tid; e < length; e += gridDim.x * blockDim.x) { + output[e] = lambda(input[e]); + } +} + +void launcher(cudaStream_t *stream, double *input, double *output, Nd4jLong length) { + //auto f = [] __host__ __device__ (double x) -> double { + // return x + 1.; + //}; + auto f = LAMBDA_D(x) { + return x+1.; + }; + + + runLambda<<<128, 128, 128, *stream>>>(input, output, length, f); +} + + +TEST_F(LambdaTests, test_basic_1) { + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + + + + //x.applyLambda(f, nullptr); + launcher(LaunchContext::defaultContext()->getCudaStream(), (double *)x.specialBuffer(), (double *)x.specialBuffer(), x.lengthOf()); + auto res = cudaStreamSynchronize(*LaunchContext::defaultContext()->getCudaStream()); + ASSERT_EQ(0, res); + + ASSERT_EQ(e, x); +} + +void test(NDArray &x) { + auto f = LAMBDA_D(x) { + return x+1.; + }; + + x.applyLambda(f, x); +} + +template +void test2(NDArray &x) { + auto f = LAMBDA_T(x) { + return x+1.; + }; + + x.applyLambda(f, x); +} + +void testPairwise(NDArray &x, NDArray &y) { + auto f = LAMBDA_DD(x, y) { 
+ return x + y +1.; + }; + + x.applyPairwiseLambda(y, f, x); +} + +void testTriplewise(NDArray &i, NDArray &j, NDArray &k) { + auto f = LAMBDA_DDD(i, j, k) { + return i + j + k + 2.; + }; + + i.applyTriplewiseLambda(j, k, f, i); +} + +void testIndexed(NDArray &x) { + auto f = ILAMBDA_D(x) { + return _idx + 1.; + }; + + x.applyIndexedLambda(f, x); +} + +void testIndexedPairwise(NDArray &x, NDArray &y) { + auto f = ILAMBDA_DD(x, y) { + return _idx + x + y +1.; + }; + + x.applyIndexedPairwiseLambda(y, f, x); +} + +TEST_F(LambdaTests, test_basic_2) { + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + + test(x); + + ASSERT_EQ(e, x); +} + +TEST_F(LambdaTests, test_basic_3) { + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + + test(x); + + ASSERT_EQ(e, x); +} + +TEST_F(LambdaTests, test_basic_4) { + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + + test2(x); + + ASSERT_EQ(e, x); +} + +TEST_F(LambdaTests, test_basic_5) { + auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); + auto e = NDArrayFactory::create('c', {5}, {4., 4., 4., 4., 4.}); + + testPairwise(x, y); + + ASSERT_EQ(e, x); +} + +TEST_F(LambdaTests, test_basic_6) { + auto x = NDArrayFactory::create('c', {5}); + auto e = NDArrayFactory::create('c', {5}, {1., 2., 3., 4., 5.}); + + testIndexed(x); + + ASSERT_EQ(e, x); +} + +TEST_F(LambdaTests, test_basic_7) { + auto w = NDArrayFactory::create('c', {5}, {0., 0., 0., 0., 0.}); + auto x = NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); + auto e = NDArrayFactory::create('c', {5}, {5., 5., 5., 5., 5.}); + + testTriplewise(w, x, y); + + ASSERT_EQ(e, w); +} + +TEST_F(LambdaTests, test_basic_8) { + auto x = 
NDArrayFactory::create('c', {5}, {1., 1., 1., 1., 1.}); + auto y = NDArrayFactory::create('c', {5}, {2., 2., 2., 2., 2.}); + auto e = NDArrayFactory::create('c', {5}, {4., 5., 6., 7., 8.}); + + testIndexedPairwise(x, y); + + ASSERT_EQ(e, x); +} + + +template +void testPairwiseMy(NDArray &x, NDArray &y, NDArray &z) { + + auto f = LAMBDA_TT(x, y){ + return sd::math::nd4j_max(x, (T)0.f) + - x * y + + sd::math::nd4j_log((T)1.f + + sd::math::nd4j_exp(-sd::math::nd4j_abs(x))); + }; + + x.applyPairwiseLambda(y, f, z); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(LambdaTests, test_basic_9) { + + NDArray labels('c', {2,3,4},{0,1,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0,1,1,0,1,0}); + NDArray logits('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray output('c', {2,3,4}, sd::DataType::DOUBLE); + NDArray expected('c', {2,3,4}, {0.744397, 0.598139, 0.554355, 0.913015, 0.474077, 1.037488, 0.403186, 1.171101, 0.341154, 1.313262, 0.287335, 1.463282, 0.241008, 1.620417, 0.201413, 1.783901, 0.167786, 1.952978, 2.039387, 0.126928, 0.115520, 2.305083, 0.095545, 2.486836}); + + logits.linspace(0.1, 0.1); + + NDArray::prepareSpecialUse({&output}, {&logits, &labels}); + testPairwiseMy(logits, labels, output); + NDArray::registerSpecialUse({&output}, {&logits, &labels}); + + // output.printBuffer(nullptr, -1, true); + ASSERT_TRUE(expected.equalsTo(output)); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LaunchContextCudaTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LaunchContextCudaTests.cu new file mode 100644 index 000000000..8c7142623 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LaunchContextCudaTests.cu @@ -0,0 +1,127 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * 
https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::ops; + +class LaunchContextCudaTests : public testing::Test { + // +}; + + +void acquireContext(int threadId, int &deviceId) { + deviceId = AffinityManager::currentDeviceId(); + + nd4j_printf("Creating thread: [%i]; assigned deviceId: [%i];\n", threadId, deviceId); + + auto lc = LaunchContext::defaultContext(); + nd4j_printf("LC: [%p]\n", lc); + + nd4j_printf("reductionPtr: [%p]; stream: [%p];\n", lc->getReductionPointer(), lc->getCudaStream()); +} + +TEST_F(LaunchContextCudaTests, basic_test_1) { + int deviceA, deviceB; + std::thread threadA(acquireContext, 0, std::ref(deviceA)); + std::thread threadB(acquireContext, 1, std::ref(deviceB)); + + threadA.join(); + threadB.join(); + nd4j_printf("All threads joined\n",""); + + if (AffinityManager::numberOfDevices() > 1) + ASSERT_NE(deviceA, deviceB); +} + +void fillArray(int tid, std::vector &arrays) { + auto array = NDArrayFactory::create_('c', {3, 10}); + nd4j_printf("Array created on device [%i]\n", AffinityManager::currentDeviceId()); + array->assign(tid); + arrays[tid] = array; +} + +TEST_F(LaunchContextCudaTests, basic_test_2) { + std::vector arrays(2); + + 
std::thread threadA(fillArray, 0, std::ref(arrays)); + std::thread threadB(fillArray, 1, std::ref(arrays)); + + threadA.join(); + threadB.join(); + + for (int e = 0; e < 2; e++) { + auto array = arrays[e]; + ASSERT_EQ(e, array->e(0)); + + delete array; + } +} + +void initAffinity(int tid, std::vector &aff) { + auto affinity = AffinityManager::currentDeviceId(); + aff[tid] = affinity; + nd4j_printf("Thread [%i] affined with device [%i]\n", tid, affinity); +} + +TEST_F(LaunchContextCudaTests, basic_test_3) { + auto totalThreads = AffinityManager::numberOfDevices() * 4; + nd4j_printf("Total threads: %i\n", totalThreads); + std::vector affinities(totalThreads); + + for (int e = 0; e < totalThreads; e++) { + std::thread thread(initAffinity, e, std::ref(affinities)); + + thread.join(); + } + + std::vector hits(AffinityManager::numberOfDevices()); + std::fill(hits.begin(), hits.end(), 0); + + // we need to make sure all threads were attached to "valid" devices + for (int e = 0; e < totalThreads; e++) { + auto aff = affinities[e]; + ASSERT_TRUE(aff >= 0 && aff < AffinityManager::numberOfDevices()); + + hits[aff]++; + } + + // now we check if all devices got some threads + for (int e = 0; e < AffinityManager::numberOfDevices(); e++) { + ASSERT_GT(hits[e], 0); + } +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LegacyOpsCudaTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LegacyOpsCudaTests.cu new file mode 100644 index 000000000..41aaf279a --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LegacyOpsCudaTests.cu @@ -0,0 +1,114 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::ops; + +class LegacyOpsCudaTests : public testing::Test { + +}; + + +TEST_F(LegacyOpsCudaTests, test_sortTad_1) { + auto x = NDArrayFactory::create('c', {3, 5}, {1.f, 3.f, 0.f, 2.f, 4.f, + 6.f, 5.f, 9.f, 7.f, 8.f, + 10.f, 11.f, 14.f, 12.f, 13.f}); + + auto e = NDArrayFactory::create('c', {3, 5}, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f}); + + int axis = 1; + auto packX = ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), axis); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + x.syncToDevice(); + sortTad(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), &axis, 1, packX.platformShapeInfo(), packX.platformOffsets(), false); + x.tickWriteDevice(); + + ASSERT_EQ(e, x); +} + +TEST_F(LegacyOpsCudaTests, test_sort_1) { + auto x = NDArrayFactory::create('c', {4}, {4.f, 2.f, 1.f, 3.f}); + auto e = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + NDArray::prepareSpecialUse({&x}, {&x}); + ::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), 
false); + NDArray::registerSpecialUse({&x}); + + ASSERT_EQ(e, x); +} + +TEST_F(LegacyOpsCudaTests, test_sort_2) { + auto x = NDArrayFactory::create('c', {4}, {4.f, 2.f, 1.f, 3.f}); + auto e = NDArrayFactory::create('c', {4}, {4.f, 3.f, 2.f, 1.f}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + NDArray::prepareSpecialUse({&x}, {&x}); + ::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), true); + NDArray::registerSpecialUse({&x}); + + ASSERT_EQ(e, x); +} + +TEST_F(LegacyOpsCudaTests, test_sort_3) { + auto x = NDArrayFactory::create('c', {4}, {0.5, 0.4, 0.1, 0.2}); + auto e = NDArrayFactory::create('c', {4}, {0.1, 0.2, 0.4, 0.5}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + NDArray::prepareSpecialUse({&x}, {&x}); + ::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false); + NDArray::registerSpecialUse({&x}); + + ASSERT_EQ(e, x); +} + +TEST_F(LegacyOpsCudaTests, test_sort_4) { + auto x = NDArrayFactory::create('c', {4}, {7, 4, 9, 2}); + auto e = NDArrayFactory::create('c', {4}, {2, 4, 7, 9}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + NDArray::prepareSpecialUse({&x}, {&x}); + ::sort(extras, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), false); + NDArray::registerSpecialUse({&x}); + + ASSERT_EQ(e, x); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LegacyOpsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LegacyOpsTests.cpp new file mode 100644 index 000000000..7b338cad2 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LegacyOpsTests.cpp @@ -0,0 +1,770 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms 
of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 16.10.2017. +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::ops; + +class LegacyOpsTests : public testing::Test { + +}; + + +TEST_F(LegacyOpsTests, TransformTests_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + auto z = NDArrayFactory::create('c', {5,5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(-1.0); + + sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg + auto status = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(status, ND4J_STATUS_OK); + //z.printIndexedBuffer("Output NEG"); + ASSERT_TRUE(z.equalsTo(&exp)); +} + +TEST_F(LegacyOpsTests, TransformTests_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(-1.0); + + sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(1, result.size()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(LegacyOpsTests, Reciprocal_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0f); + + auto ethalon = NDArrayFactory::create('c', {5, 5}); + 
ethalon.assign(0.5f); + + sd::ops::LegacyTransformSameOp op(transform::Reciprocal); // Reciprocal + Nd4jStatus status = op.execute({&x}, {&x}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(ethalon.equalsTo(&x)); + +} + +TEST_F(LegacyOpsTests, PWT_Tests_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); + + auto y = NDArrayFactory::create('c', {5, 5}); + y.assign(3.0); + + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(6.0); + + sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply + Nd4jStatus status = op.execute({&x, &y}, {&x}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(exp.equalsTo(&x)); + + +} + +TEST_F(LegacyOpsTests, PWT_Tests_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); + + auto y = NDArrayFactory::create('c', {5, 5}); + y.assign(3.0); + + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(6.0); + + sd::ops::LegacyPairwiseTransformOp op(pairwise::Multiply); // Multiply + auto result = op.evaluate({&x, &y}, {}, {}); + + auto z = result.at(0); + + //z->printBuffer("Z"); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(LegacyOpsTests, Scalar_Test_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); + + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(7.0); + + sd::ops::LegacyScalarOp op(scalar::Add); + op.execute({&x}, {&x}, {5.0}, {}, {}); // + + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(LegacyOpsTests, Scalar_Test_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(2.0); + + auto exp = NDArrayFactory::create('c', {5, 5}); + exp.assign(7.0); + + auto y = NDArrayFactory::create(5.0f); + + sd::ops::LegacyScalarOp op(scalar::Add, y); + auto result = op.evaluate({&x}, {}, {}); + + auto z = result.at(0); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(LegacyOpsTests, ReduceTests_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + int opNum = reduce::Sum; + 
sd::ops::LegacyReduceSameOp op(opNum); + + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(1, result.size()); + + auto z = result.at(0); + // z->printBuffer("ReduceTest1"); + ASSERT_TRUE(z->isScalar()); + ASSERT_NEAR(x.sumNumber().e(0), z->e(0), 1e-5f); + + +} + + +TEST_F(LegacyOpsTests, ReduceTests_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + + sd::ops::LegacyReduceSameOp op(reduce::Sum); + auto axis = NDArrayFactory::create('c', {1}, {1}); + auto result = op.evaluate({&x, &axis}, {}, {}); + + ASSERT_EQ(1, result.size()); + + auto z = result.at(0); + + auto exp = x.reduceAlongDimension(reduce::Sum, {1}); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(LegacyOpsTests, ReduceTests_3) { + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1,1}, {1}); + + + sd::ops::LegacyReduceSameOp op(reduce::Sum); + auto result = op.evaluate({&x, &indices}, {}, {}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Sum,{1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(LegacyOpsTests, ReduceTests_4) { + auto x = NDArrayFactory::create('c', {2, 3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1, 1}, {1}); + + + sd::ops::LegacyReduceSameOp op(reduce::Sum); + auto result = op.evaluate({&x, &indices}, {}, {}, {true}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Sum, {1}, true); + // indices.printShapeInfo("Indices shape"); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // z->printIndexedBuffer("Output reduce 4"); + // exp.printIndexedBuffer("Expected reduce 4"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(LegacyOpsTests, ReduceTests_5) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + int opNum = reduce::Mean; + sd::ops::LegacyReduceFloatOp 
op(opNum); + + auto result = op.evaluate({&x}); + + ASSERT_EQ(1, result.size()); + + auto z = result.at(0); + // z->printBuffer("ReduceTest1"); + ASSERT_TRUE(z->isScalar()); + ASSERT_NEAR(x.meanNumber().e(0), z->e(0), 1e-5f); + + +} + + +TEST_F(LegacyOpsTests, ReduceTests_6) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(1.0); + auto axis = NDArrayFactory::create('c', {1}, {1}); + sd::ops::LegacyReduceFloatOp op(reduce::Mean); + + auto result = op.evaluate({&x, &axis}, {}, {}); + + ASSERT_EQ(1, result.size()); + + auto z = result.at(0); + + auto exp = x.reduceAlongDimension(reduce::Mean, {1}); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(LegacyOpsTests, ReduceTests_7) { + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1,1}, {1}); + + + sd::ops::LegacyReduceFloatOp op(reduce::Mean); + auto result = op.evaluate({&x, &indices}, {}, {}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Mean,{1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(LegacyOpsTests, ReduceTests_8) { + auto x = NDArrayFactory::create('c', {2, 3, 5}); + x.linspace(1); + auto indices = NDArrayFactory::create('c', {1}, {1}); + + + sd::ops::LegacyReduceFloatOp op(reduce::Mean); + auto result = op.evaluate({&x, &indices}, {}, {}, {true}); + auto z = result.at(0); + auto exp = x.reduceAlongDimension(reduce::Mean, {1}, true); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + // z->printIndexedBuffer("Reduce8 output"); + // z->printShapeInfo("Reduce8 shape"); + // exp.printShapeInfo("Reduce8 expected shape"); + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(LegacyOpsTests, IndexReduceTests_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.linspace(1); + + sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); + + auto result = 
op.evaluate({&x}, {}, {}); + + ASSERT_EQ(1, result.size()); + + auto z = result.at(0); + + ASSERT_TRUE(z->isScalar()); + ASSERT_EQ(24, z->e(0)); + + +} + + +TEST_F(LegacyOpsTests, IndexReduceTests_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto indices = NDArrayFactory::create('c', {1}, {1}); + x.linspace(1); + auto exp = NDArrayFactory::create({4,4,4,4,4}); + sd::ops::LegacyIndexReduceOp op(indexreduce::IndexMax); + + auto result = op.evaluate({&x, &indices}, {}, {}); + + ASSERT_EQ(1, result.size()); + + auto z = result.at(0); + // z->printIndexedBuffer("Hello indexreduce2"); + ASSERT_TRUE(exp.equalsTo(z)); + //ASSERT_EQ(4, z->e(0)); + //ASSERT_EQ(4, z->e(1)); + //ASSERT_EQ(4, z->e(2)); + //ASSERT_EQ(4, z->e(3)); + //ASSERT_EQ(4, z->e(4)); + + +} + +TEST_F(LegacyOpsTests, BroadcastingTests_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + x.assign(0.0f); + + auto row = NDArrayFactory::create('c', {1, 5}); + row.linspace(1); + auto axis = NDArrayFactory::create('c', {1}, {1}); + sd::ops::LegacyBroadcastOp op(broadcast::Add); + Nd4jStatus status = op.execute({&x, &row, &axis}, {&x}, {}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto list = x.allTensorsAlongDimension({1}); + // x.printIndexedBuffer("Output broadcast"); + // list->at(0)->printIndexedBuffer("Column 0:"); + for (int e = 0; e < list.size(); e++) + ASSERT_TRUE(row.equalsTo(list.at(e))); +} + +TEST_F(LegacyOpsTests, BroadcastingTests_2) { + auto x = NDArrayFactory::create('c', {5}, {1, 1, 1, 1, 1}); + auto y = NDArrayFactory::create('c', {10, 5}); + auto e = NDArrayFactory::create('c', {10, 5}); + y.assign(3.0); + e.assign(4.0); + + int axis = 1; + + // shape::printShapeInfoLinear("tad shape", tad.tadOnlyShapeInfo); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {axis}); + + NDArray::prepareSpecialUse({&y}, {&x}); + + NativeOpExecutioner::execInverseBroadcast(LaunchContext::defaultContext(), broadcast::Add, x.buffer(), x.shapeInfo(), 
x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), &axis, 1, packY.platformShapeInfo(), packY.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + + NDArray::registerSpecialUse({&y}, {&x}); + + ASSERT_EQ(e, y); +} + +TEST_F(LegacyOpsTests, PowDerivative_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(3.f); + exp.assign(6.f); + + float p = 2.0f; + + x.applyScalar(scalar::PowDerivative, p, x); + + ASSERT_TRUE(exp.equalsTo(&x)); +} + +#ifndef __CUDABLAS__ +TEST_F(LegacyOpsTests, reduce3_1) { + + Nd4jLong yShape[2] = {4,4}; + Nd4jLong xShape[1] = {4}; + float y[16] ={1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}; + float x[4] = {1,2,3,4}; + int dimension[1] = {1}; + int dimensionLength = 1; + int opNum = 1; + float extraVals[1] = {0}; + float result[4] = {0.0,0.0,0.0,0.0}; + + std::vector dim = {1}; + + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, yShape); + auto xShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 1, xShape); + + //int *tadShapeBuffer = shape::computeResultShape(shapeBuffer,dimension,dimensionLength); + auto tadShapeBuffer = sd::ShapeUtils::evalReduceShapeInfo('c', dim, shapeBuffer, false, true, nullptr); + functions::reduce3::Reduce3::exec(opNum, x, xShapeBuffer, extraVals, y, shapeBuffer, result, tadShapeBuffer, dimension, dimensionLength, 0, 4); + + float distancesAssertion[4] = {0.0,8.0,16.0,24.0}; + for(int i = 0; i < 4; i++) + ASSERT_NEAR(distancesAssertion[i],result[i], 1e-5); + + delete[] shapeBuffer; + delete[] xShapeBuffer; +} + +#endif + + +TEST_F(LegacyOpsTests, Reduce3_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto z = NDArrayFactory::create('c', {5}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + 
dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; + #ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; + #endif + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1}); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1}); + + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); + + execReduce3Tad(extraPointers, reduce3::CosineSimilarity, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), + packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + + delete []extraPointers; +} + +TEST_F(LegacyOpsTests, Reduce3_3) { + auto x = NDArrayFactory::create('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, + -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, + 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}); + + auto y = NDArrayFactory::create('c', {5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673}); + auto e = NDArrayFactory::create('c', {3}, {0.577452, 0.0, 1.80182}); + auto z = NDArrayFactory::create('c', {3}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; + #ifdef 
__CUDABLAS__ + extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; + #endif + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1}); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1}); + + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); + + execReduce3Tad(extraPointers, reduce3::CosineDistance, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), + packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + ASSERT_EQ(e, z); + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + delete []extraPointers; +} + +TEST_F(LegacyOpsTests, Reduce3_4) { + auto x = NDArrayFactory::create('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, + -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, + 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}); + + auto y = NDArrayFactory::create('c', {1, 5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673}); + auto e = NDArrayFactory::create('c', {1, 3}, {0.577452, 0.0, 1.80182}); + auto z = NDArrayFactory::create('c', {1, 3}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; + #ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), 
nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; + #endif + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1}); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1}); + + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); + + execReduce3Tad(extraPointers, reduce3::CosineDistance, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), + packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + + // z.printIndexedBuffer("z"); + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + ASSERT_EQ(e, z); + delete []extraPointers; +} + +TEST_F(LegacyOpsTests, Reduce3_5) { + auto x = NDArrayFactory::create('c', {3, 5}, {-0.84443557262, -0.06822254508, 0.74266910552, 0.61765557527, -0.77555125951, + -0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673, + 0.62955373525, -0.31357592344, 1.03362500667, -0.59279078245, 1.1914824247}); + + auto y = NDArrayFactory::create('c', {1, 5}, {-0.99536740779, -0.0257304441183, -0.6512106060, -0.345789492130, -1.25485503673}); + auto e = NDArrayFactory::create('c', {1, 3}, {0.577452, 0.0, 1.80182}); + auto z = NDArrayFactory::create('c', {1, 3}); + + auto dim = NDArrayFactory::create('c', {1}, {1}); + dim.syncToHost(); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; + #ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), 
context->getAllocationPointer()}; + #endif + + auto packX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1}); + auto packY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), {1}); + + NDArray::prepareSpecialUse({&z}, {&x, &y, &dim}); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); + + execReduce3Tad(extraPointers, reduce3::CosineDistance, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), + packX.platformShapeInfo(), packX.platformOffsets(), packY.platformShapeInfo(), packY.platformOffsets()); + + NDArray::registerSpecialUse({&z}, {&x, &y, &dim}); + ASSERT_EQ(e, z); + delete []extraPointers; +} + +TEST_F(LegacyOpsTests, test_Reduce3_All_1) { + auto x = NDArrayFactory::create('c', {1000, 100}); + auto y = NDArrayFactory::create('c', {1, 100}); + auto z = NDArrayFactory::create('c', {1000, 1}); + auto dim = NDArrayFactory::create('c', {1}, {-1}); + + auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), -1); + auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), -1); + + sd::LaunchContext* context = sd::LaunchContext::defaultContext(); + + Nd4jPointer* extraPointers = nullptr; + #ifdef __CUDABLAS__ + extraPointers = new Nd4jPointer[7] {nullptr, context->getCudaStream(), context->getScalarPointer(), nullptr, context->getCudaSpecialStream(), context->getReductionPointer(), context->getAllocationPointer()}; + #endif + + NDArray::prepareSpecialUse({&z}, {&x, &y}); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer dimBuf(dim.dataBuffer()); + + execReduce3All(extraPointers, reduce3::EuclideanDistance, &xBuf, 
x.shapeInfo(), x.specialShapeInfo(), + nullptr, &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dim.shapeInfo(), dim.specialShapeInfo(), + tadPackX.platformShapeInfo(), tadPackX.platformOffsets(), + tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); + + NDArray::registerSpecialUse({&z}, {&x, &y}); + + delete []extraPointers; +} + + +TEST_F(LegacyOpsTests, test_inverse_broadcast_1) { + auto x = NDArrayFactory::create('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f}); + auto y = NDArrayFactory::create('c', {3, 4}); + auto e = NDArrayFactory::create('c', {3, 4}); + e.assign(2.0f); + + auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), 1); + + y.tickWriteDevice(); + + NativeOpExecutioner::execInverseBroadcast(LaunchContext::defaultContext(), broadcast::Add, + x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, 0, + tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), + tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); + + ASSERT_EQ(e, y); +} + +TEST_F(LegacyOpsTests, test_inverse_broadcast_2) { + auto x = NDArrayFactory::create('c', {4}, {2.0f, 2.0f, 2.0f, 2.0f}); + auto y = NDArrayFactory::create('c', {3, 4}); + auto z = NDArrayFactory::create('c', {3, 4}); + auto e = NDArrayFactory::create('c', {3, 4}); + e.assign(false); + + auto row = y(1, {0}); + row.assign(2.0f); + + auto erow = e(1, {0}); + erow.assign(true); + + auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), 1); + + z.tickWriteDevice(); + + NativeOpExecutioner::execInverseBroadcastBool(LaunchContext::defaultContext(), broadcast::BoolOps::EqualTo, + x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + z.buffer(), z.shapeInfo(), 
z.specialBuffer(), z.specialShapeInfo(), + nullptr, + nullptr, 0, + tadPackY.platformShapeInfo(), tadPackY.platformOffsets(), + tadPackY.platformShapeInfo(), tadPackY.platformOffsets()); + + ASSERT_EQ(e, z); +} + +TEST_F(LegacyOpsTests, test_legacy_reduce_empty_1) { + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); + + int dim = 1; + + NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Sum, + x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + &dim, 1); + + ASSERT_EQ(e, z); +} + +TEST_F(LegacyOpsTests, test_legacy_reduce_empty_2) { + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); + e.assign(std::numeric_limits::infinity()); + + int dim = 1; + + NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Min, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1); + + ASSERT_EQ(e, z); +} + +TEST_F(LegacyOpsTests, test_legacy_reduce_empty_3) { + auto x = NDArrayFactory::create('c', {2, 0, 3}); + auto z = NDArrayFactory::create('c', {2, 3}); + auto e = NDArrayFactory::create('c', {2, 3}); + e.assign(-std::numeric_limits::infinity()); + + int dim = 1; + + NativeOpExecutioner::execReduceSame(LaunchContext::defaultContext(), reduce::SameOps::Max, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), &dim, 1); + + ASSERT_EQ(e, z); +} + +TEST_F(LegacyOpsTests, test_legacy_reduce_empty_4) { + if (!Environment::getInstance().isCPU()) + return; + int a = 0; + + auto x = NDArrayFactory::create('c', {1, 0, 2}); + auto d = 
NDArrayFactory::create('c', {1}, {a}); + auto z = NDArrayFactory::create('c', {0, 2}); + auto e = NDArrayFactory::create('c', {0, 2}); + + InteropDataBuffer xdb(x.dataBuffer()); + InteropDataBuffer ddb(d.dataBuffer()); + InteropDataBuffer zdb(z.dataBuffer()); + + + ::execReduceSame2(nullptr, reduce::SameOps::Sum, + &xdb, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &zdb, z.shapeInfo(), z.specialShapeInfo(), + &ddb, d.shapeInfo(), d.specialShapeInfo()); + +} + +TEST_F(LegacyOpsTests, test_legacy_transform_float_1) { + auto x = NDArrayFactory::create('c', {1, 0, 4}); + + NativeOpExecutioner::execTransformFloat(LaunchContext::defaultContext(), transform::FloatOps::RSqrt, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, nullptr, nullptr); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ListOperationsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ListOperationsTests.cpp new file mode 100644 index 000000000..36b4c016e --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ListOperationsTests.cpp @@ -0,0 +1,663 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include + +using namespace sd; +using namespace sd::ops; + +class ListOperationsTests : public testing::Test { + +}; + +TEST_F(ListOperationsTests, BasicTest_Write_1) { + NDArrayList list(5); + auto x = NDArrayFactory::create('c', {128}); + x.linspace(1); + + sd::ops::write_list op; + + auto result = op.execute(&list, {&x}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(1, list.elements()); + + auto result2 = op.execute(&list, {&x}, {}, {2}); + + ASSERT_EQ(2, list.elements()); + + + +} + +TEST_F(ListOperationsTests, BasicTest_Stack_1) { + NDArrayList list(10); + auto exp = NDArrayFactory::create('c', {10, 100}); + auto tads = exp.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {100}); + row->assign((double) e); + list.write(e, row); + tads.at(e)->assign(row); + } + + sd::ops::stack_list op; + + auto result = op.execute(&list, {}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printShapeInfo(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ListOperationsTests, BasicTest_UnStackList_1) { + NDArrayList list(0, true); + auto x = NDArrayFactory::create('c', {10, 100}); + auto tads = x.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {100}); + row->assign((double) e); + //list.write(e, row); + tads.at(e)->assign(row); + delete row; + } + + sd::ops::unstack_list op; + + auto result = op.execute(&list, {&x}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(list.elements(), 10); + +// auto z = result.at(0); +// z->printShapeInfo("The first of"); +// ASSERT_TRUE(exp.isSameShape(z)); +// ASSERT_TRUE(exp.equalsTo(z)); + 
for (int e = 0; e < 10; e++) { + auto row = list.read(e); + ASSERT_TRUE(row->equalsTo(tads.at(e))); + //list.write(e, row); + delete row; + } + + +} + +//TEST_F(ListOperationsTests, BasicTest_UnStackList_2) { +//// NDArrayList list(0, true); +// auto x = NDArrayFactory::create('c', {10, 100}); +// auto tads = x.allTensorsAlongDimension({1}); +// for (int e = 0; e < 10; e++) { +// auto row = NDArrayFactory::create_('c', {100}); +// row->assign((double) e); +// //list.write(e, row); +// tads->at(e)->assign(row); +// delete row; +// } +// +// sd::ops::unstack_list op; +// +// auto result = op.execute(nullptr, {&x}, {}, {0}); +// +// ASSERT_EQ(ND4J_STATUS_OK, result.status()); +// ASSERT_EQ(result->size(), 10); +// +// // auto z = result.at(0); +//// z->printShapeInfo("The first of"); +//// ASSERT_TRUE(exp.isSameShape(z)); +//// ASSERT_TRUE(exp.equalsTo(z)); +// for (int e = 0; e < 10; e++) { +// auto row = result.at(e); +// ASSERT_TRUE(row->equalsTo(tads->at(e))); +// //list.write(e, row); +// } +// +// +// delete tads; +//} + +TEST_F(ListOperationsTests, BasicTest_Read_1) { + NDArrayList list(10); + auto exp = NDArrayFactory::create('c', {1, 100}); + exp.assign(4.0f); + + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {1, 100}); + row->assign((double) e); + list.write(e, new NDArray(row->dup())); + + delete row; + } + + sd::ops::read_list op; + + auto result = op.execute(&list, {}, {}, {4}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ListOperationsTests, BasicTest_Pick_1) { + NDArrayList list(10); + auto exp = NDArrayFactory::create('c', {4, 100}); + + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {100}); + row->assign((double) e); + list.write(e, new NDArray(row->dup())); + + delete row; + } + + auto tads = exp.allTensorsAlongDimension({1}); + tads.at(0)->assign(1.0f); + tads.at(1)->assign(1.0f); + 
tads.at(2)->assign(3.0f); + tads.at(3)->assign(3.0f); + + + sd::ops::pick_list op; + auto result = op.execute(&list, {}, {}, {1, 1, 3, 3}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ListOperationsTests, BasicTest_Size_1) { + NDArrayList list(10); + auto exp = NDArrayFactory::create(10); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {100}); + row->assign((double) e); + list.write(e, new NDArray(row->dup())); + + delete row; + } + + sd::ops::size_list op; + + auto result = op.execute(&list, {}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ListOperationsTests, BasicTest_Create_1) { + auto matrix = NDArrayFactory::create('c', {3, 2}); + matrix.linspace(1); + + sd::ops::create_list op; + + auto result = op.execute(nullptr, {&matrix}, {}, {1, 1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + // we return flow as well + ASSERT_EQ(1, result.size()); + + +} + +TEST_F(ListOperationsTests, BasicTest_Split_1) { + NDArrayList list(0, true); + + auto exp0 = NDArrayFactory::create('c', {2, 5}); + auto exp1 = NDArrayFactory::create('c', {3, 5}); + auto exp2 = NDArrayFactory::create('c', {5, 5}); + + auto matrix = NDArrayFactory::create('c', {10, 5}); + + auto lengths = NDArrayFactory::create('c', {3}); + lengths.p(0, 2); + lengths.p(1, 3); + lengths.p(2, 5); + + auto tads = matrix.allTensorsAlongDimension({1}); + + auto tads0 = exp0.allTensorsAlongDimension({1}); + auto tads1 = exp1.allTensorsAlongDimension({1}); + auto tads2 = exp2.allTensorsAlongDimension({1}); + + int cnt0 = 0; + int cnt1 = 0; + int cnt2 = 0; + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {5}); + row->assign((double) e); + tads.at(e)->assign(row); + + if (e < 2) + tads0.at(cnt0++)->assign(row); + else if (e 
< 5) + tads1.at(cnt1++)->assign(row); + else + tads2.at(cnt2++)->assign(row); + + delete row; + } + + sd::ops::split_list op; + auto result = op.execute(&list, {&matrix, &lengths}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + ASSERT_EQ(3, list.height()); + + ASSERT_TRUE(exp0.isSameShape(list.readRaw(0))); + ASSERT_TRUE(exp0.equalsTo(list.readRaw(0))); + + ASSERT_TRUE(exp1.isSameShape(list.readRaw(1))); + ASSERT_TRUE(exp1.equalsTo(list.readRaw(1))); + + ASSERT_TRUE(exp2.isSameShape(list.readRaw(2))); + ASSERT_TRUE(exp2.equalsTo(list.readRaw(2))); + + +} + +TEST_F(ListOperationsTests, BasicTest_Scatter_1) { + NDArrayList list(0, true); + auto s = NDArrayFactory::create(0.0); + + auto matrix = NDArrayFactory::create('c', {10, 5}); + auto tads = matrix.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {1, 5}); + row->assign((double) e); + tads.at(e)->assign(row); + + delete row; + } + auto indices = NDArrayFactory::create('c', {1, 10}); + for (int e = 0; e < matrix.rows(); e++) + indices.p(e, 9 - e); + + sd::ops::scatter_list op; + auto result = op.execute(&list, {&indices, &matrix, &s}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + for (int e = 0; e < 10; e++) { + auto row = tads.at(9 - e); + auto chunk = list.readRaw(e); + + ASSERT_TRUE(chunk->isSameShape(row)); + + ASSERT_TRUE(chunk->equalsTo(row)); + } + +} + +TEST_F(ListOperationsTests, BasicTest_Clone_1) { + auto list = new NDArrayList(0, true); + + VariableSpace variableSpace; + auto var = new Variable(nullptr, nullptr, -1, 0); + var->setNDArrayList(list); + + variableSpace.putVariable(-1, var); + variableSpace.trackList(list); + + Context block(1, &variableSpace); + block.pickInput(-1); + + sd::ops::clone_list op; + + ASSERT_TRUE(list == block.variable(0)->getNDArrayList()); + + auto result = op.execute(&block); + + ASSERT_EQ(ND4J_STATUS_OK, result); + + auto resVar = variableSpace.getVariable(1); + + auto resList = 
resVar->getNDArrayList(); + + ASSERT_TRUE( resList != nullptr); + + ASSERT_TRUE(list->equals(*resList)); +} + +TEST_F(ListOperationsTests, BasicTest_Gather_1) { + NDArrayList list(0, true); + for (int e = 0; e < 10; e++) { + auto row = NDArrayFactory::create_('c', {3}); + row->assign((double) e); + list.write(e, new NDArray(row->dup())); + + delete row; + } + + auto exp = NDArrayFactory::create('c', {10, 3}); + auto tads = exp.allTensorsAlongDimension({1}); + for (int e = 0; e < 10; e++) { + auto tad = tads.at(9 - e); + tad->assign(e); + } + + auto indices = NDArrayFactory::create('c', {1, 10}); + indices.linspace(9, -1); + + sd::ops::gather_list op; + auto result = op.execute(&list, {&indices}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + ASSERT_EQ(1, result.size()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + + //exp.printIndexedBuffer("e"); + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ListOperationsTests, GraphTests_Sequential_1) { + Graph graph; + + auto matrix = NDArrayFactory::create_('c', {3, 3}); + auto tads = matrix->allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); + } + + + auto exp = NDArrayFactory::create('c', {3, 3}); + auto tadsExp = exp.allTensorsAlongDimension({1}); + tadsExp.at(0)->assign(0.f); + tadsExp.at(1)->assign(-1.f); + tadsExp.at(2)->assign(-2.f); + + auto indices = NDArrayFactory::valueOf({3}, 1, 'c'); + //indices->linspace(0); + + + auto variableSpace = graph.getVariableSpace(); + variableSpace->putVariable(-1, matrix); + variableSpace->putVariable(-2, indices); + + + auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}); + + // creating list + sd::ops::create_list opB; + auto nodeB = new Node(&opB, 2, {1},{},{}, 0.0f, {}, {0, 1}); + //nodeB->setCustomOp(&opB); + + // filling list with matrix + sd::ops::split_list opC; + auto nodeC = new Node(&opC, 3, {2, 1, -2}); + //nodeC->setCustomOp(&opC); + + // 
reading chunks from List. We're adding op number 3 in inputs, to ensure graph will execute this node after split + sd::ops::read_list opD; + auto nodeD0 = new Node(&opD, 5, {2, 3}, {},{}, 0.0f, {}, {0}); + auto nodeD1 = new Node(&opD, 6, {2, 3}, {},{}, 0.0f, {}, {1}); + auto nodeD2 = new Node(&opD, 7, {2, 3}, {},{}, 0.0f, {}, {2}); + //nodeD0->setCustomOp(&opD); + //nodeD1->setCustomOp(&opD); + //nodeD2->setCustomOp(&opD); + + // using OneMinus on each chunk separately + auto nodeE0 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 10, {5}); + auto nodeE1 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 11, {6}); + auto nodeE2 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 12, {7}); + + // writing chunks back to the List + sd::ops::write_list opF; + auto nodeF0 = new Node(&opF, 15, {2, 10}, {},{}, 0.0f, {}, {0}); + auto nodeF1 = new Node(&opF, 16, {2, 11}, {},{}, 0.0f, {}, {1}); + auto nodeF2 = new Node(&opF, 17, {2, 12}, {},{}, 0.0f, {}, {2}); + +// nodeF0->setCustomOp(&opF); +// nodeF1->setCustomOp(&opF); +// nodeF2->setCustomOp(&opF); + + // now we're stacking chunks back to matrix state + sd::ops::stack_list opG; + auto nodeG = new Node(&opG, 20, {2, 15, 16, 17}); + //auto nodeG = new Node(OpType_CUSTOM, 0, 20, {2}); + +// nodeG->setCustomOp(&opG); + + + graph.addNode(nodeA); + graph.addNode(nodeB); + graph.addNode(nodeC); + graph.addNode(nodeD0); + graph.addNode(nodeD1); + graph.addNode(nodeD2); + graph.addNode(nodeE0); + graph.addNode(nodeE1); + graph.addNode(nodeE2); + + graph.addNode(nodeF0); + graph.addNode(nodeF1); + graph.addNode(nodeF2); + + graph.addNode(nodeG); + + // let's also validate structural integrity + graph.buildGraph(); + + ASSERT_EQ(0, nodeA->getLayer()); + ASSERT_EQ(1, nodeB->getLayer()); + ASSERT_EQ(2, nodeC->getLayer()); + + ASSERT_EQ(3, nodeD0->getLayer()); + ASSERT_EQ(3, nodeD1->getLayer()); + ASSERT_EQ(3, nodeD2->getLayer()); + + ASSERT_EQ(4, nodeE0->getLayer()); + ASSERT_EQ(4, nodeE1->getLayer()); + 
ASSERT_EQ(4, nodeE2->getLayer()); + + ASSERT_EQ(5, nodeF0->getLayer()); + ASSERT_EQ(5, nodeF1->getLayer()); + ASSERT_EQ(5, nodeF2->getLayer()); + + ASSERT_EQ(6, nodeG->getLayer()); + + auto result = GraphExecutioner::execute(&graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + + ASSERT_TRUE(variableSpace->hasVariable(2)); + auto list = variableSpace->getVariable(2)->getNDArrayList(); + + ASSERT_TRUE(list != nullptr); + + ASSERT_EQ(3, list->height()); + ASSERT_EQ(3, list->elements()); + + + ASSERT_TRUE(variableSpace->hasVariable(20)); + + auto stack = variableSpace->getVariable(20)->getNDArray(); + + ASSERT_TRUE(stack != nullptr); + + ASSERT_TRUE(exp.isSameShape(stack)); + ASSERT_TRUE(exp.equalsTo(stack)); +} + + +TEST_F(ListOperationsTests, GraphTests_Sequential_2) { + Graph graph; + + auto scalar = NDArrayFactory::create_(0.0f); + auto matrix = NDArrayFactory::create_('c', {3, 3}); + auto tads = matrix->allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + tads.at(e)->assign((float) (e+1)); + } + + auto exp = NDArrayFactory::create('c', {3, 3}); + auto tadsExp = exp.allTensorsAlongDimension({1}); + tadsExp.at(0)->assign(0.f); + tadsExp.at(1)->assign(-1.f); + tadsExp.at(2)->assign(-2.f); + + //auto indices = NDArray::valueOf({1, 3}, 1.0f, 'c'); + auto indices = NDArrayFactory::create_('c', {1, 3}); + indices->linspace(0); + + + auto variableSpace = graph.getVariableSpace(); + variableSpace->putVariable(-1, matrix); + variableSpace->putVariable(-2, indices); + variableSpace->putVariable(-3, scalar); + + + auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}); + + // creating list + sd::ops::create_list opB; + auto nodeB = new Node(&opB, 2, {1},{},{}, 0.0f, {}, {0, 1}); +// nodeB->setCustomOp(&opB); + + // filling list with matrix + sd::ops::scatter_list opC; + auto nodeC = new Node(&opC, 3, {2, -2, 1, -3}); + + //nodeC->setCustomOp(&opC); + + sd::ops::read_list opD; + auto nodeD0 = new Node(&opD, 5, {2, 3}, {},{}, 0.0f, {}, {0}); + auto nodeD1 = 
new Node(&opD, 6, {2, 3, 15}, {},{}, 0.0f, {}, {1}); + auto nodeD2 = new Node(&opD, 7, {2, 3, 16}, {},{}, 0.0f, {}, {2}); + +// nodeD0->setCustomOp(&opD); +// nodeD1->setCustomOp(&opD); +// nodeD2->setCustomOp(&opD); + + + // using OneMinus on each chunk separately + auto nodeE0 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 10, {5}); + auto nodeE1 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 11, {6}); + auto nodeE2 = new Node(OpType_TRANSFORM_SAME, sd::transform::OneMinus, 12, {7}); + + // writing chunks back to the List + sd::ops::write_list opF; + auto nodeF0 = new Node(&opF, 15, {2, 10}, {},{}, 0.0f, {}, {0}); + auto nodeF1 = new Node(&opF, 16, {2, 11}, {},{}, 0.0f, {}, {1}); + auto nodeF2 = new Node(&opF, 17, {2, 12}, {},{}, 0.0f, {}, {2}); + +// nodeF0->setCustomOp(&opF); +// nodeF1->setCustomOp(&opF); +// nodeF2->setCustomOp(&opF); + + // now we're gathering chunks back to matrix state + sd::ops::pick_list opG; + auto nodeG = new Node(&opG, 20, {2, -2, 15, 16, 17}); + //auto nodeG = new Node(OpType_CUSTOM, 0, 20, {2}); + + //nodeG->setCustomOp(&opG); + + graph.addNode(nodeA); + graph.addNode(nodeB); + graph.addNode(nodeC); + graph.addNode(nodeD0); + graph.addNode(nodeD1); + graph.addNode(nodeD2); + graph.addNode(nodeE0); + graph.addNode(nodeE1); + graph.addNode(nodeE2); + + graph.addNode(nodeF0); + graph.addNode(nodeF1); + graph.addNode(nodeF2); + + graph.addNode(nodeG); + + // let's also validate structural integrity + graph.buildGraph(); + + ASSERT_EQ(0, nodeA->getLayer()); + ASSERT_EQ(1, nodeB->getLayer()); + ASSERT_EQ(2, nodeC->getLayer()); + + ASSERT_EQ(3, nodeD0->getLayer()); + ASSERT_EQ(4, nodeE0->getLayer()); + ASSERT_EQ(5, nodeF0->getLayer()); + + ASSERT_EQ(6, nodeD1->getLayer()); + ASSERT_EQ(7, nodeE1->getLayer()); + ASSERT_EQ(8, nodeF1->getLayer()); + + ASSERT_EQ(9, nodeD2->getLayer()); + ASSERT_EQ(10, nodeE2->getLayer()); + ASSERT_EQ(11, nodeF2->getLayer()); + + ASSERT_EQ(12, nodeG->getLayer()); + + + auto result = 
GraphExecutioner::execute(&graph); + ASSERT_EQ(ND4J_STATUS_OK, result); + + ASSERT_TRUE(variableSpace->hasVariable(2)); + auto list = variableSpace->getVariable(2)->getNDArrayList(); + + ASSERT_TRUE(list != nullptr); + + ASSERT_EQ(3, list->height()); + ASSERT_EQ(3, list->elements()); + + ASSERT_TRUE(variableSpace->hasVariable(20)); + + auto stack = variableSpace->getVariable(20)->getNDArray(); + + ASSERT_TRUE(stack != nullptr); + + ASSERT_TRUE(exp.isSameShape(stack)); + ASSERT_TRUE(exp.equalsTo(stack)); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp new file mode 100644 index 000000000..7e2da69b2 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/LoopCoordsHelperTests.cpp @@ -0,0 +1,225 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author Abdelrauf + // + +#include "testlayers.h" +#include +#include +using namespace sd; + +class LoopCoordsHelper : public testing::Test { +public: + +}; + + +template +FORCEINLINE +typename std::enable_if<(Rank - 1 == rankIndex), bool>::type +eq_strides(CoordsState& cbs, const Nd4jLong* strides) { + return STRIDE(cbs, rankIndex) == strides[rankIndex]; +} + +template +FORCEINLINE +typename std::enable_if<(Rank - 1 != rankIndex), bool>::type +eq_strides(CoordsState& cbs, const Nd4jLong* strides) { + return STRIDE(cbs, rankIndex) == strides[rankIndex] && eq_strides(cbs, strides); +} + +template +FORCEINLINE +typename std::enable_if<(Rank - 1 == rankIndex), bool>::type +eq_zip_strides(ZipCoordsState& cbs, const Nd4jLong* strides1, const Nd4jLong* strides2) { + return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex]; +} + +template +FORCEINLINE +typename std::enable_if<(Rank - 1 != rankIndex), bool>::type +eq_zip_strides(ZipCoordsState& cbs, const Nd4jLong* strides1, const Nd4jLong* strides2) { + return ZIP_STRIDE1(cbs, rankIndex) == strides1[rankIndex] && ZIP_STRIDE2(cbs, rankIndex) == strides2[rankIndex] + && eq_zip_strides(cbs, strides1, strides2); +} + + + + +TEST_F(LoopCoordsHelper, Init_Tests) { + + constexpr size_t test_Index = 131; + constexpr size_t Rank = 5; + + Nd4jLong shape[Rank] = { 3, 5 ,7, 8, 9}; + Nd4jLong multiply_st[] = { 2,3,3,5,6,7,9,3 }; + Nd4jLong strides_c[Rank] ; + Nd4jLong strides_f[Rank]; + + Nd4jLong coords[Rank]; + Nd4jLong coords_f[Rank]; + + strides_f[0] = multiply_st[0] * shape[0]; + strides_c[Rank-1] = multiply_st[Rank-1] * shape[Rank-1]; + + for (int i = 1; i < Rank; i++) { + strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i]; + } + + for (int i = Rank-2; i >=0; i--) { + strides_c[i] = strides_c[i+1] * multiply_st[i] * shape[i]; + } 
+ + //init our base coords + index2coords_C(test_Index, Rank, shape, coords); + index2coords_F(test_Index, Rank, shape, coords_f); + + + size_t offset_calc = offset_from_coords(strides_c, coords, Rank); + size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank); + + CoordsState cts; + CoordsState cts_f; + + ZipCoordsState zcts; + ZipCoordsState zcts_f; + + size_t offset = init_coords(cts, test_Index, shape, strides_c); + size_t offset_f = init_coords(cts_f, test_Index, shape, strides_f); + + zip_size_t zoffset = init_coords(zcts, test_Index, shape, strides_c, strides_c); + zip_size_t zoffset_f = init_coords(zcts_f, test_Index, shape, strides_f, strides_f); + + ASSERT_TRUE(eq_coords(cts, coords)); + ASSERT_TRUE(eq_coords(cts_f, coords_f)); + + ASSERT_TRUE(eq_zip_coords(zcts, coords)); + ASSERT_TRUE(eq_zip_coords(zcts_f, coords_f)); + + ASSERT_TRUE(eq_strides(cts,strides_c)); + ASSERT_TRUE(eq_strides(cts_f,strides_f)); + + ASSERT_TRUE(eq_zip_strides(zcts, strides_c, strides_c)); + ASSERT_TRUE(eq_zip_strides(zcts_f, strides_f, strides_f)); + + + ASSERT_EQ(offset , offset_calc); + ASSERT_EQ(zoffset.first , offset_calc); + ASSERT_EQ(zoffset.second , offset_calc); + ASSERT_EQ(offset_f , offset_calc_f); + ASSERT_EQ(zoffset_f.first , offset_calc_f); + ASSERT_EQ(zoffset_f.second , offset_calc_f); +} + +TEST_F(LoopCoordsHelper, Increment_Use_Tests) { + + + constexpr size_t Rank = 4; + + Nd4jLong shape[Rank] = { 3, 5 ,7, 8 }; + Nd4jLong multiply_st[] = { 2,3,3,5,6,7,9,3 }; + Nd4jLong strides_c[Rank]; + Nd4jLong strides_f[Rank]; + + Nd4jLong coords[Rank] = {}; + Nd4jLong coords_f[Rank] = {}; + Nd4jLong coords2[Rank] = {}; + Nd4jLong coords2_f[Rank] = {}; + Nd4jLong zcoords2[Rank] = {}; + Nd4jLong zcoords2_f[Rank] = {}; + + strides_f[0] = multiply_st[0] * shape[0]; + strides_c[Rank - 1] = multiply_st[Rank - 1] * shape[Rank - 1]; + + for (int i = 1; i < Rank; i++) { + strides_f[i] = strides_f[i - 1] * multiply_st[i] * shape[i]; + } + + for (int i = Rank - 2; i >= 0; 
i--) { + strides_c[i] = strides_c[i + 1] * multiply_st[i] * shape[i]; + } + + int total = 1; + for (int i = 0; i < Rank; i++) { + total *= shape[i]; + } + + CoordsState cts; + CoordsState cts_f; + + ZipCoordsState zcts; + ZipCoordsState zcts_f; + + size_t offset = init_coords(cts, 0, shape, strides_c); + size_t offset_f = init_coords(cts_f, 0, shape, strides_f); + + zip_size_t zoffset = init_coords(zcts, 0, shape, strides_c, strides_c); + zip_size_t zoffset_f = init_coords(zcts_f, 0, shape, strides_f, strides_f); + + size_t offset2 = 0; + size_t offset2_f = 0; + zip_size_t zoffset2 = {}; + zip_size_t zoffset2_f = {}; + + for (int j = 0; j < total; j++) { + + + index2coords_C(j, Rank, shape, coords); + index2coords_F(j, Rank, shape, coords_f); + + size_t offset_calc = offset_from_coords(strides_c, coords, Rank); + size_t offset_calc_f = offset_from_coords(strides_f, coords_f, Rank); + + + ASSERT_TRUE(eq_coords(cts, coords)); + ASSERT_TRUE(eq_coords(cts_f, coords_f)); + + ASSERT_TRUE(eq_zip_coords(zcts, coords)); + ASSERT_TRUE(eq_zip_coords(zcts_f, coords_f)); + + ASSERT_EQ(offset, offset_calc); + ASSERT_EQ(zoffset.first, offset_calc); + ASSERT_EQ(zoffset.second, offset_calc); + ASSERT_EQ(offset_f, offset_calc_f); + ASSERT_EQ(zoffset_f.first, offset_calc_f); + ASSERT_EQ(zoffset_f.second, offset_calc_f); + + + ASSERT_EQ(offset2, offset_calc); + ASSERT_EQ(zoffset2.first, offset_calc); + ASSERT_EQ(zoffset2.second, offset_calc); + ASSERT_EQ(offset2_f, offset_calc_f); + ASSERT_EQ(zoffset2_f.first, offset_calc_f); + ASSERT_EQ(zoffset2_f.second, offset_calc_f); + + offset = inc_coords(cts, offset); + offset_f = inc_coords(cts_f, offset_f); + zoffset = inc_coords(zcts, zoffset); + zoffset_f = inc_coords(zcts_f, zoffset_f); + + offset2 = inc_coords(shape,strides_c, coords2, offset2, Rank); + offset2_f = inc_coords(shape, strides_f, coords2_f, offset2_f, Rank); + zoffset2 = inc_coords(shape, strides_c, strides_c, zcoords2, zoffset2, Rank); + zoffset2_f = inc_coords(shape, 
strides_f, strides_f, zcoords2_f, zoffset2_f, Rank); + + } + +} + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MemoryUtilsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MemoryUtilsTests.cpp new file mode 100644 index 000000000..a8514e5ad --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MemoryUtilsTests.cpp @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 11.10.2017. 
+// + +#include +#include +#include "testlayers.h" + +using namespace sd::memory; + +class MemoryUtilsTests : public testing::Test { +public: + +}; + +TEST_F(MemoryUtilsTests, BasicRetrieve_1) { + MemoryReport reportA; + MemoryReport reportB; + +#ifdef _WIN32 + if (1 > 0) + return; +#endif + + + MemoryUtils::retrieveMemoryStatistics(reportA); + + + ASSERT_NE(reportA, reportB); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MklDnnTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MklDnnTests.cpp new file mode 100644 index 000000000..d2616e946 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MklDnnTests.cpp @@ -0,0 +1,111 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +// +// @author raver119@gmail.com +// + +#ifdef HAVE_MKLDNN + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace sd; + +class MklDnnTests : public testing::Test { +public: + +}; + +static void printer(std::initializer_list helpers) { + + for (auto v:helpers) { + nd4j_printf("Initialized [%s]\n", v->name().c_str()); + } +} + + +TEST_F(MklDnnTests, helpers_includer) { + // we need this block, to make sure all helpers are still available within binary, and not optimized out by linker + sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv2d; + sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv2d_bp; + + sd::ops::platforms::PLATFORM_conv2d_ENGINE_CPU conv3d; + sd::ops::platforms::PLATFORM_conv2d_bp_ENGINE_CPU conv3d_bp; + + sd::ops::platforms::PLATFORM_avgpool2d_ENGINE_CPU avgpool2d; + sd::ops::platforms::PLATFORM_avgpool2d_bp_ENGINE_CPU avgpool2d_bp; + + sd::ops::platforms::PLATFORM_maxpool2d_ENGINE_CPU maxpool2d; + sd::ops::platforms::PLATFORM_maxpool2d_bp_ENGINE_CPU maxpool2d_bp; + + sd::ops::platforms::PLATFORM_avgpool3dnew_ENGINE_CPU avgpool3d; + sd::ops::platforms::PLATFORM_avgpool3dnew_bp_ENGINE_CPU avgpool3d_bp; + + sd::ops::platforms::PLATFORM_maxpool3dnew_ENGINE_CPU maxpool3d; + sd::ops::platforms::PLATFORM_maxpool3dnew_bp_ENGINE_CPU maxpool3d_bp; + + sd::ops::platforms::PLATFORM_lrn_ENGINE_CPU lrn; + + sd::ops::platforms::PLATFORM_batchnorm_ENGINE_CPU batchnorm; + + sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul; + + sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax; + + sd::ops::platforms::PLATFORM_softmax_bp_ENGINE_CPU softmax_bp; + + sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh; + + sd::ops::platforms::PLATFORM_tanh_ENGINE_CPU tanh_bp; + + sd::ops::platforms::PLATFORM_xw_plus_b_ENGINE_CPU xw_plus_b; + + sd::ops::platforms::PLATFORM_xw_plus_b_bp_ENGINE_CPU 
xw_plus_b_bp; + + + printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax, &softmax_bp, &tanh, &tanh_bp, &xw_plus_b, &xw_plus_b_bp }); +} + +TEST_F(MklDnnTests, test_tanh_1) { + auto x = NDArrayFactory::create(1.0f); + auto z = NDArrayFactory::create(0.0f); + + sd::ops::tanh op; + auto status = op.execute({&x}, {&z}); + + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(MklDnnTests, test_tanh_2) { + auto x = NDArrayFactory::create('c', {1}, {1.0f}); + auto z = NDArrayFactory::create('c', {1}, {0.0f}); + + sd::ops::tanh op; + auto status = op.execute({&x}, {&z}); + + ASSERT_EQ(Status::OK(), status); +} + +#endif \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MmapTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MmapTests.cpp new file mode 100644 index 000000000..8e747a23e --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MmapTests.cpp @@ -0,0 +1,57 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver on 5/13/2018. +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class MmapTests : public testing::Test { +public: + +}; + +TEST_F(MmapTests, Test_Basic_Mmap_1) { + // FIXME: we must adopt this for CUDA as well + if (!Environment::getInstance().isCPU()) + return; + + // just 10GB + Nd4jLong size = 100000L; + + std::ofstream ofs("file", std::ios::binary | std::ios::out); + ofs.seekp(size + 1024L); + ofs.write("", 1); + ofs.close(); + + auto result = mmapFile(nullptr, "file", size); + + ASSERT_FALSE(result == nullptr); + + munmapFile(nullptr, result, size); + + remove("file"); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MultiDataTypeTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MultiDataTypeTests.cpp new file mode 100644 index 000000000..2e780bdb0 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MultiDataTypeTests.cpp @@ -0,0 +1,1984 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace sd; + +class MultiDataTypeTests : public testing::Test { +public: + +}; + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, DataTypeUtils_Test_1) { + auto dtype = DataTypeUtils::pickPairwiseResultType(sd::INT32, sd::FLOAT32); + + ASSERT_EQ(sd::FLOAT32, dtype); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, DataTypeUtils_Test_2) { + auto dtype = DataTypeUtils::pickPairwiseResultType(sd::INT32, sd::DOUBLE); + ASSERT_EQ(sd::DOUBLE, dtype); + + ASSERT_EQ(sd::DOUBLE, DataTypeUtils::pickPairwiseResultType(sd::DOUBLE, sd::INT32)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, DataTypeUtils_Test_3) { + auto dtype = DataTypeUtils::pickPairwiseResultType(sd::FLOAT32, sd::DOUBLE); + ASSERT_EQ(sd::FLOAT32, dtype); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, Basic_Test_1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + auto x = NDArrayFactory::create('c', {2, 3}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto y = NDArrayFactory::create('c', {2, 3}, {0.0, 1.0, 2.0, 3.0, 4.0, 5.0}); + auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + + auto z = x + y; + + ASSERT_EQ(e, z); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, Basic_Test_2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + auto x = NDArrayFactory::create('c', {2, 3}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto y = NDArrayFactory::create(2.0); + 
auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + + auto z = x * y; + + ASSERT_EQ(e, z); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, Basic_Test_3) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2.0); + auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + + auto z = x * y; + + ASSERT_EQ(e, z); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, Basic_Test_4) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + auto x = NDArrayFactory::create('c', {2, 3}, {0.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f}); + auto y = NDArrayFactory::create(2.0); + auto e = NDArrayFactory::create('c', {2, 3}, {0.0f, 2.0f, 4.0f, 6.0f, 8.0f, 10.0f}); + + auto z = x * y; + + ASSERT_EQ(e, z); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, Basic_Test_5) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2); + auto e = NDArrayFactory::create('c', {2, 3}, {0, 2, 4, 6, 8, 10}); + + auto z = x * y; + + ASSERT_EQ(e, z); +} + +TEST_F(MultiDataTypeTests, Basic_Test_7) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {2, 3}, {0.f, 1.f, 2.f, 3.f, 4.f, 5.f}); + auto e = NDArrayFactory::create('c', {2, 3}, {0.f, 2.f, 4.f, 6.f, 8.f, 10.f}); + + sd::ops::add op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_EQ(e, *z); +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, Basic_Test_6) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + auto x = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2); + auto e = NDArrayFactory::create('c', {2, 3}, {0, 2, 4, 6, 8, 10}); + + auto z = x * y; + + ASSERT_EQ(e, z); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_assign_number_test1) { + NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::UINT8); + NDArray exp('c', {2, 3}, {10, 10, 10, 10, 10, 10}, sd::DataType::UINT8); + + const double number = 10.8; + x = number; + + ASSERT_EQ(x,exp); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_assign_number_test2) { + NDArray x('c', {2, 3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT64); + NDArray exp('c', {2, 3}, {1, 1, 1, 1, 1, 1}, sd::DataType::INT64); + + const bool number = 1000; + x = number; + + ASSERT_EQ(x,exp); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_assign_number_test3) { + NDArray x('c', {2, 3}, {0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp('c', {2, 3}, {1, 1, 1, 1, 1, 1}, sd::DataType::BOOL); + + const int number = 1000; + x = number; + + ASSERT_EQ(x,exp); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_repeat_test1) { + NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray y('c', {2, 4}, sd::DataType::HALF); + NDArray exp('c', {2, 4}, {0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5}, sd::DataType::HALF); + + x.repeat(1, {2}, y); + + ASSERT_EQ(y, exp); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_bufferAsT_test1) { + 
NDArray x('f', {2}, {1.5, 3.5}, sd::DataType::FLOAT32); + NDArray y('c', {}, std::vector{1.5}, sd::DataType::FLOAT32); + + const int* buffX = x.bufferAsT(); + const int* buffY = y.bufferAsT(); + + ASSERT_EQ(*buffX, *buffY); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_assign_test1) { + NDArray x('c', {2,2}, {0, 1, 2, 3}, sd::DataType::UINT8); + NDArray exp('c', {2,2}, {10, 10, 20, 20}, sd::DataType::UINT8); + + NDArray scalar1('c', {}, std::vector{10.5}, sd::DataType::FLOAT32); + NDArray scalar2('c', {}, std::vector{20.8}, sd::DataType::DOUBLE); + + x(0,{0}).assign(scalar1); + x(1,{0}).assign(scalar2); + + ASSERT_EQ(x, exp); + + x.assign(exp); + + ASSERT_EQ(x, exp); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test1) { + NDArray x('f', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray exp1('c', {}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {1,1}, std::vector{1}, sd::DataType::INT64); + NDArray exp3('c', {2}, std::vector{1,2}, sd::DataType::INT64); + + auto scalar1 = x.reduceAlongDimension(sd::reduce::CountNonZero, {}/*whole range*/); + ASSERT_EQ(scalar1, exp1); + + auto scalar2 = x.reduceAlongDimension(sd::reduce::CountZero, {}/*whole range*/, true); + ASSERT_EQ(scalar2, exp2); + + auto scalar3 = x.reduceAlongDimension(sd::reduce::CountNonZero, {1}); + ASSERT_EQ(scalar3, exp3); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test2) { + NDArray x('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray exp1('c', {}, std::vector{1.5}, sd::DataType::FLOAT32); + NDArray exp2('c', {2}, {0.5,2.5}, sd::DataType::FLOAT32); + + auto scalar1 = x.reduceAlongDimension(sd::reduce::Mean, {}/*whole range*/); + // scalar1->printShapeInfo(); + // scalar1->printIndexedBuffer(); + 
ASSERT_EQ(scalar1, exp1); + + auto scalar2 = x.reduceAlongDimension(sd::reduce::Mean, {1}); + ASSERT_EQ(scalar2, exp2); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test3) { + NDArray x('c', {2, 2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray exp1('c', {}, std::vector{8.}, sd::DataType::HALF); + NDArray exp2('c', {2}, {2.,6.}, sd::DataType::HALF); + + auto scalar1 = x.reduceAlongDimension(sd::reduce::Sum, {}/*whole range*/); + ASSERT_EQ(scalar1, exp1); + + auto scalar2 = x.reduceAlongDimension(sd::reduce::Sum, {1}); + ASSERT_EQ(scalar2, exp2); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_reduceAlongDimension_test4) { + NDArray x('c', {2, 2}, {10.5, 1.5, -2.5, -3.5}, sd::DataType::HALF); + NDArray exp1('c', {}, std::vector{1}, sd::DataType::BOOL); + NDArray exp2('c', {2}, std::vector{1, 0}, sd::DataType::BOOL); + + auto scalar1 = x.reduceAlongDimension(sd::reduce::IsPositive, {}/*whole range*/); + ASSERT_EQ(scalar1, exp1); + + auto scalar2 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); + ASSERT_EQ(scalar2, exp2); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_varianceNumber_test1) { + NDArray x('f', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{1.666666667}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{1.118033989}, sd::DataType::FLOAT32); + + auto scalar1 = x.varianceNumber(variance::SummaryStatsVariance); + ASSERT_EQ(scalar1, exp1); + + auto scalar2 = x.varianceNumber(variance::SummaryStatsStandardDeviation, false); + ASSERT_EQ(scalar2, exp2); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorPlus_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; 
+ + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::FLOAT32); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::FLOAT32); + + NDArray exp('c', {2, 2}, {-1, -1, 1, 1}, sd::DataType::FLOAT32); + + ASSERT_EQ(x1+x2, exp); + ASSERT_EQ(x1+x3, exp); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorPlus_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + const double val1 = -2; + const int val2 = -2; + NDArray exp1('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::FLOAT32); + NDArray exp3('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::HALF); + + ASSERT_EQ(x1+val1, exp1); + ASSERT_EQ(val1+x1, exp1); + + ASSERT_EQ(x2+val2, exp2); + ASSERT_EQ(val2+x2, exp2); + + ASSERT_EQ(x3+val1, exp3); + ASSERT_EQ(val1+x3, exp3); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorMinus_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::HALF); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::HALF); + + NDArray exp('c', {2, 2}, {1, 3, 3, 5}, sd::DataType::HALF); + + ASSERT_EQ(x1-x2, exp); + ASSERT_EQ(x1-x3, exp); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorMinus_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray 
x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + const double val1 = 2; + const int val2 = 2; + NDArray exp1('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {2, 1, 0, -1}, sd::DataType::DOUBLE); + NDArray exp3('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::HALF); + NDArray exp5('c', {2,2}, {2, 1, 0, -1}, sd::DataType::FLOAT32); + NDArray exp6('c', {2,2}, {2, 1, 0, -1}, sd::DataType::HALF); + + ASSERT_EQ(x1-val1, exp1); + ASSERT_EQ(val1-x1, exp2); + + ASSERT_EQ(x2-val2, exp3); + ASSERT_EQ(val2-x2, exp5); + + ASSERT_EQ(x3-val1, exp4); + ASSERT_EQ(val1-x3, exp6); +} + +//////////////////////////////////////////////////////////////////////////////// multiply +TEST_F(MultiDataTypeTests, ndarray_operatorMultiply_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {-1, -2, -1, -2}, sd::DataType::DOUBLE); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::DOUBLE); + + NDArray exp('c', {2, 2}, {0, -2, -2, -6}, sd::DataType::DOUBLE); + + ASSERT_EQ(x1*x2, exp); + ASSERT_EQ(x1*x3, exp); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorMultiply_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {0, 1, 2, 3}, sd::DataType::HALF); + const double val1 = -2; + const int val2 = -2; + NDArray exp1('c', {2,2}, {0, -2, -4, -6}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {0, -2, -4, -6}, sd::DataType::FLOAT32); + NDArray exp3('c', {2,2}, {0, -2, -4, -6}, sd::DataType::HALF); + + ASSERT_EQ(x1*val1, exp1); + ASSERT_EQ(val1*x1, exp1); + + ASSERT_EQ(x2*val2, exp2); + ASSERT_EQ(val2*x2, exp2); + + ASSERT_EQ(x3*val1, exp3); + 
ASSERT_EQ(val1*x3, exp3); +} + + +//////////////////////////////////////////////////////////////////////////////// multiply +TEST_F(MultiDataTypeTests, ndarray_operatorDivide_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2, 2}, {4, 1, 2, 3}, sd::DataType::HALF); + NDArray x2('c', {2, 2}, {-1, -2, -1, -9}, sd::DataType::DOUBLE); + NDArray x3('c', {2}, {-1, -2}, sd::DataType::FLOAT32); + + NDArray exp1('c', {2, 2}, {-4, -0.5, -2, -0.3333333}, sd::DataType::HALF); + NDArray exp2('c', {2, 2}, {-0.25, -2, -0.5, -0.666667}, sd::DataType::HALF); + + ASSERT_EQ(x1/x2, exp1); + ASSERT_EQ(x3/x1, exp2); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorDivide_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::INT64); + NDArray x2('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray x3('c', {2, 2}, {1, 2, 3, 4}, sd::DataType::HALF); + const double val1 = 2; + const int val2 = -2; + NDArray exp1('c', {2,2}, {0.5, 1, 1.5, 2}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {2, 1, 0.666667, 0.5}, sd::DataType::DOUBLE); + NDArray exp3('c', {2,2}, {0, -1, -1, -2}, sd::DataType::INT64); + NDArray exp4('c', {2,2}, {-2, -1, 0., 0.}, sd::DataType::INT64); + NDArray exp5('c', {2,2}, {-0.5, -1, -1.5, -2}, sd::DataType::FLOAT32); + NDArray exp6('c', {2,2}, {-2, -1, -0.666667, -0.5}, sd::DataType::FLOAT32); + NDArray exp7('c', {2,2}, {0.5, 1, 1.5, 2}, sd::DataType::HALF); + NDArray exp8('c', {2,2}, {2, 1, 0.666667, 0.5}, sd::DataType::HALF); + + ASSERT_EQ(x1/val1, exp1); + ASSERT_EQ(val1/x1, exp2); + + ASSERT_EQ(x1/val2, exp3); + ASSERT_EQ(val2/x1, exp4); + + ASSERT_EQ(x2/val2, exp5); + ASSERT_EQ(val2/x2, exp6); + + ASSERT_EQ(x3/val1, exp7); + ASSERT_EQ(val1/x3, exp8); +} + +//////////////////////////////////////////////////////////////////////////////// 
+TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray scalar1('c', {0}, std::vector{4}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{1.5}, sd::DataType::HALF); + + NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); + NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); + NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + + NDArray exp1('c', {0}, std::vector{5}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{6.5}, sd::DataType::HALF); + NDArray exp3('c', {3,2}, {11, 22, 33, 44, 55, 66}, sd::DataType::INT64); + NDArray exp4('c', {2,3}, {12.5, 24.5, 36.5, 48.5, 60.5, 72.5}, sd::DataType::FLOAT32); + NDArray exp5('c', {2,2}, {0.4, 1.5, 2.4, 3.5}, sd::DataType::HALF); + + scalar1 += scalar2; + ASSERT_EQ(scalar1, exp1); + + scalar2 += scalar1; + ASSERT_EQ(scalar2, exp2); + + x2 += x1; + ASSERT_EQ(x2, exp3); + + x1 += x2; + ASSERT_EQ(x1, exp4); + + x4 += x3; + ASSERT_EQ(x4, exp5); + + x6 += x5; + ASSERT_EQ(x6, exp5); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorPlusEqual_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + + const Nd4jLong val1 = 1; + const float16 val2 = 1.5; + const double val3 = 2.2; + + NDArray exp1('c', {2,2}, {1, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp2('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT32); + NDArray exp3('c', {2,2}, {2.5, 3.5, 4.5, 5.5}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,2}, {2, 3, 4.5, 5}, sd::DataType::INT32); + NDArray exp5('c', {2,2}, {4.7, 
5.7, 6.7, 7.7}, sd::DataType::FLOAT32); + NDArray exp6('c', {2,2}, {4, 5, 6, 7}, sd::DataType::INT32); + + x1 += val1; + ASSERT_EQ(x1, exp1); + + x2 += val1; + ASSERT_EQ(x2, exp2); + + x1 += val2; + ASSERT_EQ(x1, exp3); + + x2 += val2; + ASSERT_EQ(x2, exp4); + + x1 += val3; + ASSERT_EQ(x1, exp5); + + x2 += val3; + ASSERT_EQ(x2, exp6); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray scalar1('c', {0}, std::vector{4}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{1.5}, sd::DataType::HALF); + + NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); + NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); + NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + + NDArray exp1('c', {0}, std::vector{2}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{-0.5}, sd::DataType::HALF); + NDArray exp3('c', {3,2}, {8, 17, 26, 35, 44, 53}, sd::DataType::INT64); + NDArray exp4('c', {2,3}, {-6.5, -14.5, -22.5, -30.5, -38.5, -46.5}, sd::DataType::FLOAT32); + NDArray exp5('c', {2,2}, {0.4, -0.5, -1.6, -2.5}, sd::DataType::HALF); + + scalar1 -= scalar2; + ASSERT_EQ(scalar1, exp1); + + scalar2 -= scalar1; + ASSERT_EQ(scalar2, exp2); + + x2 -= x1; + ASSERT_EQ(x2, exp3); + + x1 -= x2; + ASSERT_EQ(x1, exp4); + + x4 -= x3; + ASSERT_EQ(x4, exp5); + + x6 -= x5; + ASSERT_EQ(x6, exp5); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorMinusEqual_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray 
x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + + const Nd4jLong val1 = 1; + const float16 val2 = 1.5; + const double val3 = 2.2; + + NDArray exp1('c', {2,2}, {-1, 0, 1, 2}, sd::DataType::FLOAT32); + NDArray exp2('c', {2,2}, {-1, 0, 1, 2}, sd::DataType::INT32); + NDArray exp3('c', {2,2}, {-2.5, -1.5, -0.5, 0.5}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,2}, {-2., -1., 0., 0.}, sd::DataType::INT32); + NDArray exp5('c', {2,2}, {-4.7, -3.7, -2.7, -1.7}, sd::DataType::FLOAT32); + NDArray exp6('c', {2,2}, {-4, -3, -2, -2}, sd::DataType::INT32); + + x1 -= val1; + ASSERT_EQ(x1, exp1); + + x2 -= val1; + ASSERT_EQ(x2, exp2); + + x1 -= val2; + ASSERT_EQ(x1, exp3); + + x2 -= val2; + ASSERT_EQ(x2, exp4); + + x1 -= val3; + ASSERT_EQ(x1, exp5); + + x2 -= val3; + ASSERT_EQ(x2, exp6); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray scalar1('c', {0}, std::vector{3}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{2.5}, sd::DataType::HALF); + + NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); + NDArray x2('c', {3,2}, {1, 2, 3, 4, 5, 6}, sd::DataType::INT64); + NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + + NDArray exp1('c', {0}, std::vector{7}, sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{17.5}, sd::DataType::HALF); + NDArray exp3('c', {3,2}, {1, 5, 10, 18, 27, 39}, sd::DataType::INT64); + NDArray exp4('c', {2,3}, {1.5, 12.5, 35, 81, 148.5, 253.5}, sd::DataType::FLOAT32); + NDArray exp5('c', {2,2}, {0., 0.5, 0.8, 1.5}, sd::DataType::HALF); + + scalar1 *= scalar2; + ASSERT_EQ(scalar1, exp1); + + scalar2 *= scalar1; + ASSERT_EQ(scalar2, exp2); + + 
x2 *= x1; + ASSERT_EQ(x2, exp3); + + x1 *= x2; + ASSERT_EQ(x1, exp4); + + x4 *= x3; + ASSERT_EQ(x4, exp5); + + x6 *= x5; + ASSERT_EQ(x6, exp5); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorMultiplyEqual_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + + const Nd4jLong val1 = 1; + const float16 val2 = 1.5; + const double val3 = 2.2; + + NDArray exp1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray exp2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray exp3('c', {2,2}, {0, 1.5, 3, 4.5}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,2}, {0, 1, 3, 4}, sd::DataType::INT32); + NDArray exp5('c', {2,2}, {0, 3.3, 6.6, 9.9}, sd::DataType::FLOAT32); + NDArray exp6('c', {2,2}, {0, 2, 6, 8}, sd::DataType::INT32); + + x1 *= val1; + ASSERT_EQ(x1, exp1); + + x2 *= val1; + ASSERT_EQ(x2, exp2); + + x1 *= val2; + ASSERT_EQ(x1, exp3); + + x2 *= val2; + ASSERT_EQ(x2, exp4); + + x1 *= val3; + ASSERT_EQ(x1, exp5); + + x2 *= val3; + ASSERT_EQ(x2, exp6); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray scalar1('c', {0}, std::vector{3}, sd::DataType::INT32); + NDArray scalar2('c', {0}, std::vector{2.5}, sd::DataType::HALF); + + NDArray x1('c', {2,3}, {1.5, 2.5, 3.5, 4.5, 5.5, 6.5}, sd::DataType::FLOAT32); + NDArray x2('c', {3,2}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT64); + NDArray x3('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT64); + NDArray x4('c', {2}, {0.4, 0.5}, sd::DataType::HALF); + NDArray x5('c', {2,2}, {1, 2, 3, 4}, sd::DataType::HALF); + NDArray x6('c', {2}, {0.4, 0.5}, sd::DataType::FLOAT32); + + NDArray exp1('c', {0}, std::vector{1}, 
sd::DataType::INT32); + NDArray exp2('c', {0}, std::vector{2.5}, sd::DataType::HALF); + NDArray exp3('c', {3,2}, {6, 8, 8, 8, 9, 9}, sd::DataType::INT64); + NDArray exp4('c', {2,3}, {0.25, 0.3125, 0.4375, 0.5625, 0.611111111, 0.722222222}, sd::DataType::FLOAT32); + NDArray exp5('c', {2,2}, {0.4, 0.25, 0.1333333, 0.125}, sd::DataType::HALF); + + scalar1 /= scalar2; + ASSERT_EQ(scalar1, exp1); + + scalar2 /= scalar1; + ASSERT_EQ(scalar2, exp2); + + x2 /= x1; + ASSERT_EQ(x2, exp3); + + x1 /= x2; + ASSERT_EQ(x1, exp4); + + x4 /= x3; + ASSERT_EQ(x4, exp5); + + x6 /= x5; + ASSERT_EQ(x6, exp5); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_operatorDivideEqual_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 2, 4, 6}, sd::DataType::FLOAT32); + NDArray x2('c', {2,2}, {0, 2, 4, 6}, sd::DataType::INT32); + + const Nd4jLong val1 = 1; + const float16 val2 = 2.; + const double val3 = 2.2; + + NDArray exp1('c', {2,2}, {0, 2, 4, 6}, sd::DataType::FLOAT32); + NDArray exp2('c', {2,2}, {0, 2, 4, 6}, sd::DataType::INT32); + NDArray exp3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray exp5('c', {2,2}, {0, 0.45454545, 0.909090909, 1.363636364}, sd::DataType::FLOAT32); + NDArray exp6('c', {2,2}, {0, 0, 0, 1}, sd::DataType::INT32); + + x1 /= val1; + ASSERT_EQ(x1, exp1); + + x2 /= val1; + ASSERT_EQ(x2, exp2); + + x1 /= val2; + ASSERT_EQ(x1, exp3); + + x2 /= val2; + ASSERT_EQ(x2, exp4); + + x1 /= val3; + ASSERT_EQ(x1, exp5); + + x2 /= val3; + ASSERT_EQ(x2, exp6); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_reduceNumberFloat_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, 
sd::DataType::HALF); + NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + + NDArray exp1('c', {0}, std::vector{1.5}, sd::DataType::FLOAT32); + NDArray exp2('c', {0}, std::vector{2}, sd::DataType::HALF); + NDArray exp3('c', {0}, std::vector{2}, sd::DataType::DOUBLE); + NDArray exp4('c', {0}, std::vector{0.25},sd::DataType::FLOAT32); + + + NDArray scalar = x1.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp1); + + scalar = x2.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp2); + x2.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp2); + + scalar = x3.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp3); + x3.reduceNumber(reduce::Mean,scalar); + ASSERT_EQ(scalar, exp3); + + scalar = x4.reduceNumber(reduce::Mean); + ASSERT_EQ(scalar, exp4); + x4.reduceNumber(reduce::Mean, scalar); + ASSERT_EQ(scalar, exp4); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_reduceNumberSame_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + + NDArray exp1('c', {0}, std::vector{6}, sd::DataType::INT64); + NDArray exp2('c', {0}, std::vector{8}, sd::DataType::HALF); + NDArray exp3('c', {0}, std::vector{8}, sd::DataType::DOUBLE); + NDArray exp4('c', {0}, std::vector{1}, sd::DataType::BOOL); + + + NDArray scalar = x1.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp1); + + scalar = x2.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp2); + x2.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp2); + + 
scalar = x3.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp3); + x3.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp3); + + scalar = x4.reduceNumber(reduce::Sum); + ASSERT_EQ(scalar, exp4); + x4.reduceNumber(reduce::Sum, scalar); + ASSERT_EQ(scalar, exp4); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_reduceNumberBool_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, -1, 2, -3}, sd::DataType::INT64); + NDArray x2('c', {2,2}, {0.5, -1.5, 2.5, -3.5}, sd::DataType::HALF); + NDArray x3('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {-2, -1, 0, 1}, sd::DataType::BOOL); + + NDArray exp1('c', {0}, std::vector{1}, sd::DataType::BOOL); + + NDArray scalar = x1.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); + + scalar = x2.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x2.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); + + scalar = x3.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x3.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); + + scalar = x4.reduceNumber(reduce::IsFinite); + ASSERT_EQ(scalar, exp1); + x4.reduceNumber(reduce::IsFinite, scalar); + ASSERT_EQ(scalar, exp1); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_reduceNumberLong_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2,2}, {0.5, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2,2}, {0.5, -1.5, 0, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + + NDArray exp1('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {0}, std::vector{4}, 
sd::DataType::INT64); + NDArray exp3('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray exp4('c', {0}, std::vector{2}, sd::DataType::INT64); + + NDArray scalar = x1.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp1); + x1.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp1); + + scalar = x2.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp2); + x2.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp2); + + scalar = x3.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp3); + x3.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp3); + + scalar = x4.reduceNumber(reduce::CountNonZero); + ASSERT_EQ(scalar, exp4); + x4.reduceNumber(reduce::CountNonZero, scalar); + ASSERT_EQ(scalar, exp4); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_indexReduceNumber_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT32); + NDArray x2('c', {2,2}, {0.5, 1.5, -4.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2,2}, {0, -1, 0, 1}, sd::DataType::BOOL); + + NDArray exp1('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray exp2('c', {0}, std::vector{2}, sd::DataType::INT64); + NDArray exp3('c', {0}, std::vector{1}, sd::DataType::INT64); + + NDArray scalar = x1.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + ASSERT_EQ(scalar, exp1); + + scalar = x2.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + ASSERT_EQ(scalar, exp2); + + scalar = x3.indexReduceNumber(sd::indexreduce::IndexAbsoluteMax); + ASSERT_EQ(scalar, exp3); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyTransformFloat_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 4, 9, 16}, sd::DataType::INT64); + NDArray x2('c', {2,2}, {0, 2.25, 6.25, 
12.25}, sd::DataType::HALF); + NDArray x3('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + + NDArray exp1('c', {2,2}, {0, 2, 3, 4}, sd::DataType::FLOAT32); + NDArray exp2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray exp3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray exp4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::HALF); + + NDArray result1('c', {2,2}, sd::DataType::FLOAT32); + NDArray result2('c', {2,2}, sd::DataType::DOUBLE); + NDArray result3('c', {2,2}, sd::DataType::HALF); + + x1.applyTransform(sd::transform::Sqrt, result1); + ASSERT_EQ(result1, exp1); + + x2.applyTransform(sd::transform::Sqrt, result2); + ASSERT_EQ(result2, exp2); + + x3.applyTransform(sd::transform::Sqrt, result3); + ASSERT_EQ(result3, exp3); + + x4.applyTransform(sd::transform::Sqrt, result3); + ASSERT_EQ(result3, exp4); + + x2.applyTransform(sd::transform::Sqrt, x2); + ASSERT_EQ(x2, exp3); + + x3.applyTransform(sd::transform::Sqrt, x3); + ASSERT_EQ(x3, exp2); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyTransformSame_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x5('c', {2,3}, {0, 1.5, 2.5, 3.5, 4.5, 5.5}, sd::DataType::DOUBLE); + + NDArray exp1('c', {2,2}, {0, 1, 4, 9}, sd::DataType::INT64); + NDArray exp2('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::HALF); + NDArray exp3('c', {2,2}, {0, 2.25, 6.25, 12.25}, sd::DataType::DOUBLE); + NDArray exp4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp5('c', {3,2}, {0, 2.25, 6.25, 12.25, 20.25, 30.25}, sd::DataType::DOUBLE); + + NDArray result1('c', {2,2}, 
sd::DataType::INT64); + NDArray result2('c', {2,2}, sd::DataType::HALF); + NDArray result3('c', {2,2}, sd::DataType::DOUBLE); + NDArray result4('c', {2,2}, sd::DataType::BOOL); + NDArray result5('c', {3,2}, sd::DataType::DOUBLE); + + x1.applyTransform(sd::transform::Square, result1); + ASSERT_EQ(result1, exp1); + + x2.applyTransform(sd::transform::Square, result2); + ASSERT_EQ(result2, exp2); + + x3.applyTransform(sd::transform::Square, result3); + ASSERT_EQ(result3, exp3); + + x4.applyTransform(sd::transform::Square, result4); + ASSERT_EQ(result4, exp4); + + x2.applyTransform(sd::transform::Square, x2); + ASSERT_EQ(x2, exp2); + + x3.applyTransform(sd::transform::Square, x3); + ASSERT_EQ(x3, exp3); + + x5.applyTransform(sd::transform::Square, result5); + ASSERT_EQ(result5, exp5); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyTransformBool_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::HALF); + NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x5('c', {2,3}, {0, 1.5, 2.5, 3.5, 4.5, 5.5}, sd::DataType::DOUBLE); + + NDArray exp1('c', {2,2}, {0, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp2('c', {2,2}, {0, 1, 0, 0}, sd::DataType::BOOL); + NDArray exp3('c', {3,2}, {0, 0, 0, 0, 0, 1}, sd::DataType::BOOL); + + NDArray result1('c', {2,2}, sd::DataType::BOOL); + NDArray result2('c', {3,2}, sd::DataType::BOOL); + + /* + x1.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp1); + + x2.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp1); + + x3.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp1); + + x4.applyTransform(sd::transform::IsMax, result1); + ASSERT_EQ(result1, exp2); + + 
x5.applyTransform(sd::transform::IsMax, result2); + ASSERT_EQ(result2, exp3); + */ +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyTransformStrict_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::HALF); + NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::FLOAT32); + NDArray x3('c', {2,2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x4('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + + NDArray exp1('c', {2,2}, {0, 3, 12, 27}, sd::DataType::HALF); + NDArray exp2('c', {2,2}, {0, 3, 12, 27}, sd::DataType::FLOAT32); + NDArray exp3('c', {2,2}, {0, 3, 12, 27}, sd::DataType::DOUBLE); + NDArray exp4('c', {3,2}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); + NDArray exp5('c', {2,3}, {0, 3, 12, 27, 48, 75}, sd::DataType::DOUBLE); + + NDArray result1('c', {2,2}, sd::DataType::HALF); + NDArray result2('c', {2,2}, sd::DataType::FLOAT32); + NDArray result3('c', {2,2}, sd::DataType::DOUBLE); + NDArray result4('c', {3,2}, sd::DataType::DOUBLE); + + x1.applyTransform(sd::transform::CubeDerivative, result1); + ASSERT_EQ(result1, exp1); + + x2.applyTransform(sd::transform::CubeDerivative, result2); + ASSERT_EQ(result2, exp2); + + x3.applyTransform(sd::transform::CubeDerivative, result3); + ASSERT_EQ(result3, exp3); + + x4.applyTransform(sd::transform::CubeDerivative, result4); + ASSERT_EQ(result4, exp4); + + x1.applyTransform(sd::transform::CubeDerivative, x1); + ASSERT_EQ(x1, exp1); + + x2.applyTransform(sd::transform::CubeDerivative, x2); + ASSERT_EQ(x2, exp2); + + x3.applyTransform(sd::transform::CubeDerivative, x3); + ASSERT_EQ(x3, exp3); + + x4.applyTransform(sd::transform::CubeDerivative, x4); + ASSERT_EQ(x4, exp5); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyPairwiseTransform_test1) { + if 
(!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::INT32); + NDArray x2('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + NDArray x3('c', {2,3}, {0, 1, 0, 1, 0, 0}, sd::DataType::BOOL); + NDArray x4('c', {3,2}, {0.5, 1.5, 2.5, 3.5, 4.5, 0}, sd::DataType::DOUBLE); + NDArray x5('c', {3,2}, sd::DataType::INT32); + NDArray x6('c', {2,3}, sd::DataType::DOUBLE); + + NDArray exp1('c', {2,3}, {0, 2, 4, 6, 8, 5}, sd::DataType::INT32); + NDArray exp2('c', {2,3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, sd::DataType::FLOAT32); + NDArray exp3('c', {2,3}, {1, 1, 1, 1, 1, 0}, sd::DataType::BOOL); + NDArray exp4('c', {2,3}, {0.5, 2.5, 4.5, 6.5, 8.5, 5.}, sd::DataType::DOUBLE); + NDArray exp5('c', {3,2}, {0, 2, 4, 6, 8, 5}, sd::DataType::INT32); + + x1.applyPairwiseTransform(sd::pairwise::Add, x4, x5); + ASSERT_EQ(x5, exp5); + + x1.applyPairwiseTransform(sd::pairwise::Add, x4, x6); + ASSERT_EQ(x6, exp4); + + x1.applyPairwiseTransform(sd::pairwise::Add, x4); + ASSERT_EQ(x1, exp1); + + x2.applyPairwiseTransform(sd::pairwise::Add, x4); + ASSERT_EQ(x2, exp2); + + x3.applyPairwiseTransform(sd::pairwise::Add, x4); + ASSERT_EQ(x3, exp3); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyPairwiseTransform_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,3}, {1, 1, 2, 3, 4, 5}, sd::DataType::INT32); + NDArray x2('c', {3,2}, {1, 0, 2, 0, 4, 0}, sd::DataType::INT32); + NDArray x3('c', {3,2}, {0.5, 1.5, 2.5, 3, 4.5, 0}, sd::DataType::DOUBLE); + NDArray x4('c', {2,3}, {0.5, 1., 2.5, 3, 4., 0}, sd::DataType::DOUBLE); + NDArray x5('c', {3,2}, {0, 1, 0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x6('c', {2,3}, {1, 1, 1, 0, 1, 0}, sd::DataType::BOOL); + + NDArray x7('c', {3,2}, sd::DataType::BOOL); + NDArray x8('c', {2,3}, sd::DataType::BOOL); + + NDArray exp1('c', {3,2}, {1, 0, 1, 0, 1, 0}, 
sd::DataType::BOOL); + NDArray exp2('c', {2,3}, {1, 0, 1, 1, 0, 1}, sd::DataType::BOOL); + NDArray exp3('c', {2,3}, {0, 1, 0, 0, 0, 0}, sd::DataType::BOOL); + + x1.applyPairwiseTransform(sd::pairwise::EqualTo, x2, x7); + ASSERT_EQ(x7, exp1); + + x3.applyPairwiseTransform(sd::pairwise::EqualTo, x4, x8); + ASSERT_EQ(x8, exp2); + + x5.applyPairwiseTransform(sd::pairwise::EqualTo, x6, x8); + ASSERT_EQ(x8, exp3); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyBroadcast_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,3}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT32); + NDArray x2('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray x3('c', {2,3}, sd::DataType::INT32); + NDArray x4('c', {2}, {1, 2}, sd::DataType::FLOAT32); + NDArray x5('c', {2,3}, sd::DataType::FLOAT32); + NDArray x6('c', {2}, {1, 1}, sd::DataType::BOOL); + + NDArray exp1('c', {2,3}, {11, 21, 31, 42, 52, 62}, sd::DataType::INT32); + NDArray exp2('c', {2,3}, {11, 21, 31, 42, 52, 62}, sd::DataType::FLOAT32); + NDArray exp3('c', {2,3}, {11, 21, 31, 41, 51, 61}, sd::DataType::INT32); + + x1.applyBroadcast(sd::broadcast::Add, {0}, x2, x3); + ASSERT_EQ(x3, exp1); + + x1.applyBroadcast(sd::broadcast::Add, {0}, x4, x5); + ASSERT_EQ(x5, exp2); + + x1.applyBroadcast(sd::broadcast::Add, {0}, x6, x3); + ASSERT_EQ(x3, exp3); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyBroadcast_test2) { + + NDArray x1('c', {2,3}, {10, 20, 30, 40, 50, 60}, sd::DataType::INT32); + NDArray x2('c', {2}, {10, 60}, sd::DataType::INT32); + NDArray x3('c', {2,3}, sd::DataType::BOOL); + + NDArray x4('c', {2,3}, {0, 0, 0, 0, 0, 1}, sd::DataType::BOOL); + NDArray x5('c', {2}, {0, 1}, sd::DataType::BOOL); + + NDArray exp1('c', {2,3}, {1, 0, 0, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp2('c', {2,3}, {1, 1, 1, 0, 0, 1}, sd::DataType::BOOL); 
+ + x1.applyBroadcast(sd::broadcast::EqualTo, {0}, x2, x3); + ASSERT_EQ(x3, exp1); + + x4.applyBroadcast(sd::broadcast::EqualTo, {0}, x5, x3); + ASSERT_EQ(x3, exp2); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {10, 20, 30, 40}, sd::DataType::INT32); + NDArray x2('c', {2}, {1, 2}, sd::DataType::HALF); + NDArray x3('c', {2,2}, sd::DataType::HALF); + + NDArray x4('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray x5('c', {2,2}, sd::DataType::INT32); + + NDArray x6('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + NDArray x7('c', {2}, {1, 2}, sd::DataType::INT64); + NDArray x8('c', {2,2}, sd::DataType::BOOL); + + NDArray x13('c', {0}, std::vector{3}, sd::DataType::INT64); + NDArray x14('c', {0}, std::vector{1.5}, sd::DataType::DOUBLE); + NDArray x15(sd::DataType::DOUBLE); + NDArray x16('c', {2,2}, sd::DataType::DOUBLE); + + NDArray exp1('c', {2,2}, {11, 22, 31, 42}, sd::DataType::HALF); + NDArray exp2('c', {2,2}, {11, 22, 31, 42}, sd::DataType::INT32); + NDArray exp3('c', {2,2}, {1, 1, 1, 1}, sd::DataType::BOOL); + NDArray exp4('c', {0}, std::vector{4.5}, sd::DataType::DOUBLE); + NDArray exp5('c', {2,2}, {11.5, 21.5, 31.5, 41.5}, sd::DataType::DOUBLE); + + x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2, x3); + ASSERT_EQ(x3, exp1); + + x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x4, x5); + ASSERT_EQ(x5, exp2); + + x6.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x7, x8); + ASSERT_EQ(x8, exp3); + + auto x9 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2); + ASSERT_EQ(x9, exp1); + + auto x10 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x4); + ASSERT_EQ(x10, exp2); + + auto x11 = x6.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x7); + ASSERT_EQ(x11, exp3); + + auto x12 = x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x2); + 
ASSERT_EQ(x12, exp1); + + x13.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x14, x15); + ASSERT_EQ(x15, exp4); + + x1.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x14, x16); + ASSERT_EQ(x16, exp5); + + x14.applyTrueBroadcast(sd::BroadcastOpsTuple::Add(), x1, x16); + ASSERT_EQ(x16, exp5); + +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyTrueBroadcast_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {10, 20, 30, 40}, sd::DataType::HALF); + NDArray x2('c', {2}, {10, 40}, sd::DataType::HALF); + NDArray x3('c', {2,2}, sd::DataType::BOOL); + NDArray x4('c', {0}, std::vector{10}, sd::DataType::HALF); + NDArray x5('c', {0}, std::vector{20}, sd::DataType::HALF); + NDArray x6(sd::DataType::BOOL); + + NDArray exp1('c', {2,2}, {1, 0, 0, 1}, sd::DataType::BOOL); + NDArray exp2('c', {2,2}, {1, 0, 0, 0}, sd::DataType::BOOL); + NDArray exp3('c', {0}, std::vector{0}, sd::DataType::BOOL); + + x1.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x2, x3); + ASSERT_EQ(x3, exp1); + + x1.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x4, x3); + ASSERT_EQ(x3, exp2); + + x4.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x1, x3); + ASSERT_EQ(x3, exp2); + + x5.applyTrueBroadcast(BroadcastBoolOpsTuple(sd::scalar::EqualTo, sd::pairwise::EqualTo, sd::broadcast::EqualTo), x4, x6); + ASSERT_EQ(x6, exp3); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyScalar_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x3('c', {2,2}, 
sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {0, 1, 0, 1}, sd::DataType::BOOL); + + NDArray exp1('c', {2,2}, {1, 2, 3, 4}, sd::DataType::INT64); + NDArray exp2('c', {2,2}, {1.5, 2.5, 3.5, 4.5}, sd::DataType::DOUBLE); + NDArray exp3('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,2}, {1.1, 2.1, 1.1, 2.1}, sd::DataType::DOUBLE); + NDArray exp5('c', {2,2}, {1, 1, 1, 1}, sd::DataType::BOOL); + + x1.applyScalar(sd::scalar::Add, 1, x1); + ASSERT_EQ(x1, exp1); + + x1.applyScalar(sd::scalar::Add, 0.5, x3); + ASSERT_EQ(x3, exp2); + + x2.applyScalar(sd::scalar::Add, 0.1, x2); + ASSERT_EQ(x2, exp3); + + x4.applyScalar(sd::scalar::Add, 1.1, x3); + ASSERT_EQ(x3, exp4); + + x4.applyScalar(sd::scalar::Add, 1, x4); + ASSERT_EQ(x4, exp5); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyScalar_test2) { + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x2('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x3('c', {2,2}, {0, 1, 1, 0}, sd::DataType::BOOL); + NDArray x4('c', {2,2}, sd::DataType::BOOL); + + + NDArray exp1('c', {2,2}, {0, 1, 0, 0}, sd::DataType::BOOL); + NDArray exp2('c', {2,2}, {0, 1, 1, 0}, sd::DataType::BOOL); + + x1.applyScalar(sd::scalar::EqualTo, 1, x4); + ASSERT_EQ(x4, exp1); + + x2.applyScalar(sd::scalar::EqualTo, 1.5, x4); + ASSERT_EQ(x4, exp1); + + x3.applyScalar(sd::scalar::EqualTo, true, x4); + ASSERT_EQ(x4, exp2); + +} + +#ifndef __CUDABLAS__ +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyLambda_test1) { + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2,2}, sd::DataType::DOUBLE); + NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2,2}, {0, 
-1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2,2}, sd::DataType::BOOL); + + const float item1 = 0.1; + const double item2 = 0.1; + auto func1 = [=](float elem) { return elem + item1; }; + auto func2 = [=](int elem) { return elem + item1; }; + auto func3 = [=](int elem) { return elem + item2; }; + auto func4 = [=](double elem) { return elem + item1; }; + auto func5 = [=](float elem) { return elem - (int)1; }; + + NDArray exp1('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray exp3('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2,2}, {1, 0, 0, 0}, sd::DataType::BOOL); + + x1.applyLambda(func1, x4); + ASSERT_EQ(x4, exp1); + + x2.applyLambda(func1, x2); + ASSERT_EQ(x2, exp2); + + x2.applyLambda(func2, x2); + ASSERT_EQ(x2, exp2); + + x3.applyLambda(func3, x3); + ASSERT_EQ(x3, exp3); + + x5.applyLambda(func4, x5); + // x5.printBuffer(); + ASSERT_EQ(x5, exp4); + + x6.applyLambda(func5, x7); + ASSERT_EQ(x7, exp5); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyIndexedLambda_test1) { + + NDArray x1('c', {2,2}, {0, 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2,2}, {0, 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2,2}, sd::DataType::DOUBLE); + NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2,2}, {1, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2,2}, sd::DataType::BOOL); + + const float item1 = 0.1; + const double item2 = 0.1; + auto func1 = [=](Nd4jLong idx, float elem) { return idx + elem + item1; }; + auto func2 = [=](Nd4jLong idx, int elem) { return idx + elem + item1; }; + auto func3 = [=](Nd4jLong idx, int elem) { return idx + elem + item2; }; + auto func4 = [=](Nd4jLong idx, 
double elem) { return idx + elem + item1; }; + auto func5 = [=](Nd4jLong idx, float elem) { return idx + elem - (int)1; }; + + NDArray exp1('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {0, 2, 4, 6}, sd::DataType::INT64); + NDArray exp3('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2,2}, {0, 1, 1, 1}, sd::DataType::BOOL); + NDArray exp6('c', {2,2}, {0, 3, 6, 9}, sd::DataType::INT64); + + x1.applyIndexedLambda(func1, x4); + ASSERT_EQ(x4, exp1); + + x2.applyIndexedLambda(func1, x2); + ASSERT_EQ(x2, exp2); + + x2.applyIndexedLambda(func2, x2); + ASSERT_EQ(x2, exp6); + + x3.applyIndexedLambda(func3, x3); + ASSERT_EQ(x3, exp3); + + x5.applyIndexedLambda(func4, x5); + ASSERT_EQ(x5, exp4); + + x6.applyIndexedLambda(func5, x7); + ASSERT_EQ(x7, exp5); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyPairwiseLambda_test1) { + + NDArray x1('c', {2,2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2,2}, sd::DataType::DOUBLE); + NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2,2}, sd::DataType::BOOL); + NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::FLOAT32); + NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::DOUBLE); + NDArray other3('c', {2,2}, {0., -1, -2, -3}, sd::DataType::INT64); + NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, sd::DataType::BOOL); + + auto func1 = [](float elem1, float elem2) { return elem1 + elem2; }; + auto func2 = [](int elem1, float elem2) { return elem1 + elem2; }; + auto func3 = [](int elem1, double elem2) { return elem1 + elem2; }; + auto func4 = [](double 
elem1, float elem2) { return elem1 + elem2; }; + auto func5 = [](float elem1, int elem2) { return elem1 - elem2; }; + + NDArray exp1('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {0., 0, 0, 0}, sd::DataType::INT64); + NDArray exp3('c', {2,2}, {0.1, 1.1, 2.1, 3.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,2}, {0.1, 1.6, 2.6, 3.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2,2}, {0., 1, 0, 1}, sd::DataType::BOOL); + + x1.applyPairwiseLambda(other2, func1, x4); + ASSERT_EQ(x4, exp1); + + x2.applyPairwiseLambda(other3, func1, x2); + ASSERT_EQ(x2, exp2); + + x2.applyPairwiseLambda(other3, func2, x2); + ASSERT_EQ(x2, other3); + + x3.applyPairwiseLambda(other1, func3, x3); + ASSERT_EQ(x3, exp3); + + x5.applyPairwiseLambda(other1, func4, x5); + ASSERT_EQ(x5, exp4); + + x6.applyPairwiseLambda(other4, func5, x7); + ASSERT_EQ(x7, exp5); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyIndexedPairwiseLambda_test1) { + + NDArray x1('c', {2,2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT64); + NDArray x3('c', {2,2}, {0., 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x4('c', {2,2}, sd::DataType::DOUBLE); + NDArray x5('c', {2,2}, {0, 1.5, 2.5, 3.5}, sd::DataType::FLOAT32); + NDArray x6('c', {2,2}, {0.1, -1, -1, 0.1}, sd::DataType::BOOL); + NDArray x7('c', {2,2}, sd::DataType::BOOL); + NDArray other1('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::FLOAT32); + NDArray other2('c', {2,2}, {0.1, 0.1, 0.1, 0.1}, sd::DataType::DOUBLE); + NDArray other3('c', {2,2}, {0., -1, -2, -3}, sd::DataType::INT64); + NDArray other4('c', {2,2}, {1, 0, 0.1, 0}, sd::DataType::BOOL); + + auto func1 = [](Nd4jLong idx, float elem1, float elem2) { return elem1 + elem2 + idx; }; + auto func2 = [](Nd4jLong idx, int elem1, float elem2) { return elem1 + elem2 + idx; }; + auto func3 = [](Nd4jLong idx, int elem1, double elem2) { return 
elem1 + elem2 + idx; }; + auto func4 = [](Nd4jLong idx, double elem1, float elem2) { return elem1 + elem2 + idx; }; + auto func5 = [](Nd4jLong idx, float elem1, int elem2) { return elem1 - elem2 + idx; }; + + NDArray exp1('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT64); + NDArray exp3('c', {2,2}, {0.1, 2.1, 4.1, 6.1}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,2}, {0.1, 2.6, 4.6, 6.6}, sd::DataType::FLOAT32); + NDArray exp5('c', {2,2}, {0., 1, 1, 1}, sd::DataType::BOOL); + + x1.applyIndexedPairwiseLambda(other2, func1, x4); + ASSERT_EQ(x4, exp1); + + x2.applyIndexedPairwiseLambda(other3, func1, x2); + ASSERT_EQ(x2, exp2); + + x2.applyIndexedPairwiseLambda(other3, func2, x2); + ASSERT_EQ(x2, exp2); + + x3.applyIndexedPairwiseLambda(other1, func3, x3); + ASSERT_EQ(x3, exp3); + + x5.applyIndexedPairwiseLambda(other1, func4, x5); + ASSERT_EQ(x5, exp4); + + x6.applyIndexedPairwiseLambda(other4, func5, x7); + ASSERT_EQ(x7, exp5); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyTriplewiseLambda_test1) { + + NDArray x1('c', {2,2}, {0., 1, 2, 3}, sd::DataType::DOUBLE); + NDArray x2('c', {2,2}, {0., -1, -2, -3}, sd::DataType::DOUBLE); + NDArray x3('c', {2,2}, {0, -1.5, -2.5, -3.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, sd::DataType::DOUBLE); + + NDArray x5('c', {2,2}, {0., 1, 2, 3}, sd::DataType::INT32); + NDArray x6('c', {2,2}, {0., -1, -2, -3}, sd::DataType::INT32); + NDArray x7('c', {2,2}, {0., 10, 20, 30}, sd::DataType::INT32); + + NDArray x8('c', {2,2}, {0., 1, 0, 1}, sd::DataType::BOOL); + NDArray x9('c', {2,2}, {1., 1, 0, 1}, sd::DataType::BOOL); + NDArray x10('c', {2,2}, {0., 0, 0, 0}, sd::DataType::BOOL); + + auto func1 = [](double elem1, float elem2, int elem3) { return elem1 + elem2 + elem3; }; + auto func2 = [](float elem1, float elem2, float elem3) { return elem1 + elem2 + elem3; }; + auto func3 = 
[](int elem1, int elem2, int elem3) { return elem1 + elem2 + elem3; }; + auto func4 = [](bool elem1, bool elem2, bool elem3) { return elem1 + elem2 + elem3; }; + + NDArray exp('c', {2,2}, {1., 1, 0, 1}, sd::DataType::BOOL); + + x1.applyTriplewiseLambda(x2, x3, func1, x4); + ASSERT_EQ(x4, x2); + + x1.applyTriplewiseLambda(x2, x3, func2, x1); + ASSERT_EQ(x1, x3); + + x5.applyTriplewiseLambda(x6, x7, func3, x5); + ASSERT_EQ(x5, x7); + + x8.applyTriplewiseLambda(x9, x10, func4, x8); + ASSERT_EQ(x8, exp); +} + +#endif + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test1) { + + NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{5}, sd::DataType::INT64); + NDArray exp2('c', {2}, {2,2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1,1,1}, sd::DataType::INT64); + + NDArray scalar = x1.applyIndexReduce(sd::indexreduce::IndexMax, {0,1}); + ASSERT_EQ(scalar, exp1); + + NDArray vec1 = x1.applyIndexReduce(sd::indexreduce::IndexMax, {1}); + ASSERT_EQ(vec1, exp2); + + NDArray vec2 = x1.applyIndexReduce(sd::indexreduce::IndexMax, {0}); + ASSERT_EQ(vec2, exp3); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, ndarray_applyIndexReduce_test2) { + + NDArray x1('c', {2,3}, {0, 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray scalar('c', {}, std::vector{5}, sd::DataType::INT64); + NDArray vec1('c', {2}, {2,2}, sd::DataType::INT64); + NDArray vec2('c', {3}, {1,1,1}, sd::DataType::INT64); + NDArray exp1('c', {}, std::vector{5}, sd::DataType::INT64); + NDArray exp2('c', {2}, {2,2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1,1,1}, sd::DataType::INT64); + + x1.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0,1}); + ASSERT_EQ(scalar, exp1); + + x1.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {1}); + ASSERT_EQ(vec1, exp2); + + x1.applyIndexReduce(sd::indexreduce::IndexMax, 
vec2, {0}); + ASSERT_EQ(vec2, exp3); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, applyReduce3_test1) { + + NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); + NDArray x2('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); + NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); + NDArray exp1('c', {}, std::vector{-30}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15}, sd::DataType::DOUBLE); + + auto result = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_EQ(result, exp1); + + result = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_EQ(result, exp2); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, applyReduce3_test2) { + + NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); + NDArray x2('c', {2,2}, {-1,-2,-3,-4}, sd::DataType::INT32); + NDArray x3('c', {2,2}, {1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); + NDArray x5('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); + NDArray x6('c', {2,3}, {-6,-5,-4,-3,-2,-1}, sd::DataType::INT32); + NDArray x7('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); + NDArray x8('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); + + NDArray exp1('c', {}, std::vector{-30}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{15}, sd::DataType::DOUBLE); + NDArray exp3('c', {3}, {-18,-20,-18}, sd::DataType::FLOAT32); + NDArray exp4('c', {2}, {-28,-28}, sd::DataType::FLOAT32); + NDArray exp5('c', {3}, {7.5,10.5,13.5}, sd::DataType::DOUBLE); + NDArray exp6('c', {2}, {9,22.5}, sd::DataType::DOUBLE); + + auto result = x1.applyReduce3(reduce3::Dot, x2, {0,1}); + ASSERT_EQ(result, exp1); + + result = x3.applyReduce3(reduce3::Dot, x4, {0,1}); + ASSERT_EQ(result, exp2); + + result = x5.applyReduce3(reduce3::Dot, x6, std::vector({0})); + ASSERT_EQ(result, exp3); + + result = 
x5.applyReduce3(reduce3::Dot, x6, std::vector({1})); + ASSERT_EQ(result, exp4); + + result = x8.applyReduce3(reduce3::Dot, x7, std::vector({0})); + ASSERT_EQ(result, exp5); + + result = x8.applyReduce3(reduce3::Dot, x7, std::vector({1})); + ASSERT_EQ(result, exp6); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, applyAllReduce3_test1) { + + NDArray x1('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); + NDArray x2('c', {2,3}, {-1,1,-1,1,-1,1}, sd::DataType::INT32); + NDArray x3('c', {2,3}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {2,2}, {1,2,3,4}, sd::DataType::DOUBLE); + NDArray exp1('c', {2,3}, {2,-2,2,2,-2,2}, sd::DataType::FLOAT32); + NDArray exp2('c', {2,3}, {6,6,6,9,9,9}, sd::DataType::DOUBLE); + + auto result = x1.applyAllReduce3(reduce3::Dot, x2, {0}); + ASSERT_EQ(result, exp1); + + result = x4.applyAllReduce3(reduce3::Dot, x3, {0}); + ASSERT_EQ(result, exp2); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, RowCol_test1) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); + NDArray x2('c', {2}, {0.5,0.6}, sd::DataType::FLOAT32); + NDArray x3('c', {3}, {1.5,1.6,1.7}, sd::DataType::FLOAT32); + NDArray x4('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); + NDArray x5('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); + + NDArray exp1('c', {2,3}, {2,3,4,5,6,7}, sd::DataType::INT32); + NDArray exp2('c', {2,3}, {0,1,2,3,4,5}, sd::DataType::INT32); + NDArray exp3('c', {2,3}, {1.5,2.5,3.5,4.6,5.6,6.6}, sd::DataType::DOUBLE); + NDArray exp4('c', {2,3}, {0,1,1,2,3,3}, sd::DataType::INT32); + + x1.addiRowVector(x3); + ASSERT_EQ(x1, exp1); + + x1.addiColumnVector(x2); + ASSERT_EQ(x1, exp1); + + x4.addiColumnVector(x2); + ASSERT_EQ(x4, exp3); + + x5.muliColumnVector(x2); + ASSERT_EQ(x5, exp4); +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, RowCol_test2) { + if (!Environment::getInstance().isExperimentalBuild()) + return; + + NDArray x1('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::INT32); + NDArray x2('c', {2}, {0.5,0.6}, sd::DataType::FLOAT32); + NDArray x3('c', {3}, {1.5,1.6,1.7}, sd::DataType::FLOAT32); + NDArray x4('c', {2,3}, sd::DataType::FLOAT32); + NDArray x5('c', {3}, {1,2,3}, sd::DataType::INT64); + NDArray x6('c', {2,3}, sd::DataType::INT32); + NDArray x7('c', {3}, {1.5,1.6,1.7}, sd::DataType::DOUBLE); + NDArray x8('c', {2,3}, {1,2,3,4,5,6}, sd::DataType::FLOAT32); + NDArray x9('c', {3}, {1,2,3}, sd::DataType::DOUBLE); + NDArray x10('c', {2,3}, sd::DataType::DOUBLE); + + NDArray exp1('c', {2,3}, {2.5,3.6,4.7,5.5,6.6,7.7}, sd::DataType::FLOAT32); + NDArray exp2('c', {2,3}, {2, 4, 6, 5, 7, 9}, sd::DataType::INT32); + NDArray exp3('c', {2,3}, {-0.5,0.4,1.3,2.5,3.4,4.3}, sd::DataType::FLOAT32); + NDArray exp4('c', {2,3}, {1,4,9,4,10,18}, sd::DataType::DOUBLE); + NDArray exp5('c', {2,3}, {1,1,1,4,2.5,2}, sd::DataType::DOUBLE); + NDArray exp6('c', {2,3}, {1.5,2.5,3.5,4.6,5.6,6.6}, sd::DataType::FLOAT32); + + x1.addRowVector(x3, x4); + ASSERT_EQ(x4, exp1); + + x1.addRowVector(x5, x6); + ASSERT_EQ(x6, exp2); + + x8.subRowVector(x7, x4); + ASSERT_EQ(x4, exp3); + + x1.mulRowVector(x9, x10); + ASSERT_EQ(x10, exp4); + + x1.divRowVector(x9, x10); + ASSERT_EQ(x10, exp5); + + x1.addColumnVector(x2, x4); + ASSERT_EQ(x4, exp6); +} + +////////////////////////////////////////////////////////////////////// +/* +TEST_F(MultiDataTypeTests, tile_test1) { + + NDArray x1('c', {2,1}, {0,1}, sd::DataType::INT32); + NDArray x2('c', {2,1}, {0.5,1.5}, sd::DataType::DOUBLE); + NDArray x3('c', {2,2}, sd::DataType::INT32); + NDArray x4('c', {2,2}, sd::DataType::DOUBLE); + NDArray x5('c', {1,2}, {0.5,1.5}, sd::DataType::DOUBLE);; + NDArray x6('c', {2,2}, sd::DataType::FLOAT32); + NDArray x7('c', {2,2}, sd::DataType::BOOL); + + 
NDArray exp1('c', {2,2}, {0,0,1,1}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {0.5,1.5,0.5,1.5}, sd::DataType::FLOAT32); + NDArray exp3('c', {2,2}, {0,0,1,1}, sd::DataType::INT32); + NDArray exp4('c', {2,2}, {0,0,1,1}, sd::DataType::BOOL); + + x1.tile({1,2}, x4); + ASSERT_EQ(x4, exp1); + + x2.tile({1,2}, x3); + ASSERT_EQ(x3, exp3); + + x1.tile({1,2}, x7); + ASSERT_EQ(x7, exp4); + + x1.tile(x4); + ASSERT_EQ(x4, exp1); + + x2.tile(x3); + ASSERT_EQ(x3, exp3); + + x1.tile(x7); + ASSERT_EQ(x7, exp4); +} +*/ + +////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, asT_test1) { + + NDArray x1('c', {2}, {1.5, 2.5}, sd::DataType::FLOAT32); + + NDArray exp1('c', {2}, {1, 2}, sd::DataType::INT32); + NDArray exp2('c', {2}, {1.5, 2.5}, sd::DataType::DOUBLE); + + auto result = new NDArray(x1.asT()); + ASSERT_EQ(*result, exp1); + delete result; + + result = new NDArray(x1.asT()); + ASSERT_EQ(*result, exp2); + delete result; + + result = new NDArray(x1.asT(sd::DataType::INT32)); + ASSERT_EQ(*result, exp1); + delete result; + + result = new NDArray(x1.asT(sd::DataType::DOUBLE)); + ASSERT_EQ(*result, exp2); + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, assign_test2) { + + NDArray x1('c', {2,3}, {1.5,2.5,3.5,4.5,5.5,6.5}, sd::DataType::FLOAT32); + NDArray x2('c', {3,2}, sd::DataType::INT32); + NDArray x3('c', {3,2}, sd::DataType::DOUBLE); + NDArray x4('c', {3,2}, sd::DataType::BOOL); + NDArray x5('c', {2,3}, {1.5,2.5,0,4.5,5.5,6.5}, sd::DataType::FLOAT32); + + NDArray exp1('c', {3,2}, {1, 2,3,4,5,6}, sd::DataType::INT32); + NDArray exp2('c', {3,2}, {1.5,2.5,3.5,4.5,5.5,6.5}, sd::DataType::DOUBLE); + NDArray exp3('c', {3,2}, {1,1,0,1,1,1}, sd::DataType::BOOL); + + x2.assign(x1); + ASSERT_EQ(x2, exp1); + + x3.assign(x1); + ASSERT_EQ(x3, exp2); + + x4.assign(x5); + ASSERT_EQ(x4, exp3); +} + +TEST_F(MultiDataTypeTests, Test_Cast_1) { + auto first = 
NDArrayFactory::create('c', {10}); + auto asBool = NDArrayFactory::create('c', {10}); + auto _not = NDArrayFactory::create('c', {10}); + auto asFloat = NDArrayFactory::create('c', {10}); + auto exp = NDArrayFactory::create('c', {10}); + exp.assign(0.0f); + + asBool.assign(first); + + // asBool.printIndexedBuffer("asBool"); + asBool.applyScalar(scalar::Not, false, _not); + + // _not.printIndexedBuffer("_not"); + + asFloat.assign(_not); + + // asFloat.printIndexedBuffer("asFloat"); + ASSERT_EQ(exp, asFloat); +} + +TEST_F(MultiDataTypeTests, Test_Cast_2) { + auto first = NDArrayFactory::create('c', {10}); + auto asBool = NDArrayFactory::create('c', {10}); + auto _not = NDArrayFactory::create('c', {10}); + auto asFloat = NDArrayFactory::create('c', {10}); + auto exp = NDArrayFactory::create('c', {10}); + exp.assign(1.0f); + + asBool.assign(first); + + // asBool.printIndexedBuffer("asBool"); + asBool.applyTransform(transform::Not, _not); + + // _not.printIndexedBuffer("_not"); + + asFloat.assign(_not); + + // asFloat.printIndexedBuffer("asFloat"); + ASSERT_EQ(exp, asFloat); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, divide_bool_test1) { + + NDArray x1('c', {2,3}, {1.5,0,3.5,0,5.5,6.5}, sd::DataType::FLOAT32); + NDArray x2('c', {3,2}, {1,1,0,1,0,1}, sd::DataType::BOOL); + NDArray x3('c', {2,3}, sd::DataType::FLOAT32); + NDArray x4('c', {2}, sd::DataType::BOOL); + + try { + NDArray x3 = x1 / x2; + } + catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1 /= x2; + } + catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + NDArray x3 = 150. 
/ x2; + } + catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1.divRowVector(x4, x3); + } + catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1.applyBroadcast(sd::broadcast::FloorDiv, {1}, x4, x3); + } + catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } + + try { + x1.applyTrueBroadcast(BROADCAST(FloorMod), x2, x3); + } + catch (std::exception& message) { + // printf("%s\n", message.what()); + ASSERT_TRUE(1); + } +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, aaa) { + + NDArray z('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); + z.permutei({1,0}); + + sd::graph::RandomGenerator gen(119,5); + ExtraArguments extras({1.5, 2.5}); + + NativeOpExecutioner::execRandom(LaunchContext::defaultContext(), sd::random::UniformDistribution, + &gen, + z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + extras.argumentsAsT()); + // z.printIndexedBuffer(); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(MultiDataTypeTests, assign_2) +{ + NDArray x('c', {4}, {1.5,2.5,3.5,4.5}, sd::DataType::FLOAT32); + NDArray y('c', {4}, sd::DataType::INT32); + NDArray expected('c', {4}, {1,2,3,4}, sd::DataType::INT32); + + y.assign(x); + // y.printBuffer(); + + ASSERT_TRUE(expected.equalsTo(&y)); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MultiDeviceTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MultiDeviceTests.cpp new file mode 100644 index 000000000..3ea90eb27 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/MultiDeviceTests.cpp @@ -0,0 +1,72 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * 
* terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include + + +using namespace sd; + +class MultiDeviceTests : public testing::Test { +public: + +}; + +void createArrays(int limit, std::vector &arrays) { + auto deviceId = AffinityManager::currentDeviceId(); + auto numDevices = AffinityManager::numberOfDevices(); + + for (int e = 0; e < limit; e++) { + auto value = deviceId * limit + e; + arrays[value] = NDArrayFactory::create_('c', {10}); + arrays[value]->assign(value); + //nd4j_printf("device_%i; value: [%i]; mean: [%f]\n", deviceId, value, arrays[value]->meanNumber().e(0)); + } +} + +TEST_F(MultiDeviceTests, test_multi_device_migration_1) { + auto deviceId = AffinityManager::currentDeviceId(); + auto numDevices = AffinityManager::numberOfDevices(); + auto numArrays = 10; + std::vector arrays(numDevices * numArrays); + + // filling list of arrays on multiple threads + for (int e = 0; e < numDevices; e++) { + std::thread t1(createArrays, numArrays, std::ref(arrays)); + + t1.join(); + } + + // at this moment all arrays are build, so we can test migration + for (int e = 0; e < arrays.size(); e++) { + ASSERT_NEAR((float) e, arrays[e]->meanNumber().e(0), 1e-5f); + delete arrays[e]; + } +} diff 
--git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayConstructorsTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayConstructorsTests.cu new file mode 100644 index 000000000..1a3e07a9a --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayConstructorsTests.cu @@ -0,0 +1,208 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace sd; +using namespace sd::graph; + +class NDArrayConstructorsTests : public testing::Test { +public: + +}; + +TEST_F(NDArrayConstructorsTests, test_constructor_1) { + auto x = NDArrayFactory::empty_(); + + ASSERT_TRUE(x->buffer() == nullptr); + ASSERT_TRUE(x->specialBuffer() == nullptr); + + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); + + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_TRUE(x->isActualOnHostSide()); + + delete x; +} + +TEST_F(NDArrayConstructorsTests, test_constructor_2) { + auto x = NDArrayFactory::vector(5, 1.0f); + + + ASSERT_FALSE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); + + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); + + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + + delete x; +} + +TEST_F(NDArrayConstructorsTests, test_constructor_3) { + auto x = NDArrayFactory::create('c',{5, 5}); + + ASSERT_TRUE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); + + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); +} + +TEST_F(NDArrayConstructorsTests, test_constructor_4) { + auto x = NDArrayFactory::create(sd::DataType::FLOAT32, 1.0f); + + ASSERT_FALSE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); + + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_TRUE(x.isActualOnHostSide()); +} + +TEST_F(NDArrayConstructorsTests, 
test_constructor_5) { + auto x = NDArrayFactory::create('c',{2, 2}, {1, 2, 3, 4}); + + ASSERT_TRUE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); + + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); +} + +TEST_F(NDArrayConstructorsTests, test_constructor_6) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray y(x); + + ASSERT_TRUE(y.buffer() == nullptr); + ASSERT_FALSE(y.specialBuffer() == nullptr); + + ASSERT_FALSE(y.shapeInfo() == nullptr); + ASSERT_FALSE(y.specialShapeInfo() == nullptr); + + ASSERT_TRUE(y.isActualOnDeviceSide()); + ASSERT_FALSE(y.isActualOnHostSide()); +} + +TEST_F(NDArrayConstructorsTests, test_constructor_7) { + auto x = NDArrayFactory::create(1.0f); + + ASSERT_FALSE(x.buffer() == nullptr); + ASSERT_FALSE(x.specialBuffer() == nullptr); + + ASSERT_FALSE(x.shapeInfo() == nullptr); + ASSERT_FALSE(x.specialShapeInfo() == nullptr); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_TRUE(x.isActualOnHostSide()); +} + +TEST_F(NDArrayConstructorsTests, test_constructor_8) { + auto x = NDArrayFactory::create_('c',{2, 2}, {1, 2, 3, 4}); + + ASSERT_TRUE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); + + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); + + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + + delete x; +} + +TEST_F(NDArrayConstructorsTests, test_constructor_9) { + auto x = NDArrayFactory::create_('c',{2, 2}); + + ASSERT_TRUE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); + + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); + + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + + delete x; +} + +TEST_F(NDArrayConstructorsTests, test_linspace_1) { + auto x = 
NDArrayFactory::linspace(1.0f, 10.0f, 20); + + ASSERT_FALSE(x->buffer() == nullptr); + ASSERT_FALSE(x->specialBuffer() == nullptr); + + ASSERT_FALSE(x->shapeInfo() == nullptr); + ASSERT_FALSE(x->specialShapeInfo() == nullptr); + + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_TRUE(x->isActualOnHostSide()); + + delete x; +} + +TEST_F(NDArrayConstructorsTests, test_constructor_10) { + + NDArray scalar1(sd::DataType::DOUBLE); // scalar1 = 0 + NDArray scalar2('c', {}, std::vector{0}); + + ASSERT_TRUE(scalar1.isActualOnDeviceSide()); + ASSERT_TRUE(!scalar1.isActualOnHostSide()); + ASSERT_TRUE(scalar2.isActualOnDeviceSide()); + ASSERT_TRUE(scalar2.isActualOnHostSide()); + + ASSERT_TRUE(scalar2.equalsTo(scalar1)); + + ASSERT_TRUE(scalar1.isActualOnDeviceSide()); + ASSERT_TRUE(!scalar1.isActualOnHostSide()); + ASSERT_TRUE(scalar2.isActualOnDeviceSide()); + ASSERT_TRUE(scalar2.isActualOnHostSide()); + + ASSERT_TRUE(scalar1.buffer() == nullptr); + ASSERT_TRUE(scalar1.specialBuffer() != nullptr); + ASSERT_TRUE(scalar1.shapeInfo() != nullptr); + ASSERT_TRUE(scalar1.specialShapeInfo() != nullptr); + ASSERT_TRUE(scalar1.lengthOf() == 1); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu new file mode 100644 index 000000000..935294fd6 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayCudaBasicsTests.cu @@ -0,0 +1,2200 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // + // @author raver119@gmail.com + // + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +using namespace sd; +using namespace sd::graph; + +class NDArrayCudaBasicsTests : public testing::Test { +public: + +}; + +////////////////////////////////////////////////////////////////////////// +static cudaError_t allocateDeviceMem(LaunchContext& lc, std::vector& devicePtrs, const std::vector>& hostData) { + + if(devicePtrs.size() != hostData.size()) + throw std::invalid_argument("prepareDataForCuda: two input sts::vectors should same sizes !"); + + cudaError_t cudaResult; + + void* reductionPointer; + cudaResult = cudaMalloc(reinterpret_cast(&reductionPointer), 1024*1024); if(cudaResult != 0) return cudaResult; + int* allocationPointer; + cudaResult = cudaMalloc(reinterpret_cast(&allocationPointer), 1024*1024); if(cudaResult != 0) return cudaResult; + + lc.setReductionPointer(reductionPointer); + lc.setAllocationPointer(allocationPointer); + cudaStream_t stream = *lc.getCudaStream(); + + for(int i = 0; i < devicePtrs.size(); ++i) { + + cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), hostData[i].second); if(cudaResult != 0) return cudaResult; + cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice, stream); + } + return cudaResult; +} + +TEST_F(NDArrayCudaBasicsTests, Test_Registration_1) { + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = 
NDArrayFactory::create('c', {5}, {5, 4, 3, 2, 1}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); +} + +TEST_F(NDArrayCudaBasicsTests, Test_Registration_2) { + auto x = NDArrayFactory::create('c', {5}); + auto y = NDArrayFactory::create('c', {5}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); +} + +TEST_F(NDArrayCudaBasicsTests, Test_Registration_3) { + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}, {5, 4, 3, 2, 1}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + NDArray::registerSpecialUse({&x}, {&y}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + ASSERT_TRUE(y.isActualOnDeviceSide()); + ASSERT_FALSE(y.isActualOnHostSide()); +} + +TEST_F(NDArrayCudaBasicsTests, Test_Registration_01) { + auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); + + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + delete x; + delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_Registration_02) { + auto x = NDArrayFactory::create_('c', {5}); + auto y = NDArrayFactory::create_('c', {5}); + + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + delete x; + delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_Registration_03) { + auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); + + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + + NDArray::registerSpecialUse({y}, {x}); + x->applyTransform(transform::Neg, *y); + //ASSERT_TRUE(x->isActualOnDeviceSide()); + //ASSERT_FALSE(x->isActualOnHostSide()); + + //ASSERT_TRUE(y->isActualOnDeviceSide()); + //ASSERT_TRUE(y->isActualOnHostSide()); + //y->syncToHost(); + // y->printBuffer("Negatives"); + delete 
x; + delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_Cosine_1) { + auto x = NDArrayFactory::create_('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create_('c', {5}, {5, 4, 3, 2, 1}); + + ASSERT_TRUE(x->isActualOnDeviceSide()); + ASSERT_FALSE(x->isActualOnHostSide()); + + NDArray::registerSpecialUse({y}, {x}); + x->applyTransform(transform::Cosine, *y); + //ASSERT_TRUE(x->isActualOnDeviceSide()); + //ASSERT_FALSE(x->isActualOnHostSide()); + + //ASSERT_TRUE(y->isActualOnDeviceSide()); + //ASSERT_TRUE(y->isActualOnHostSide()); + //y->syncToHost(); + delete x; + delete y; +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestAdd_1) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', { 5 }, {10, 10, 10, 10, 10}); + + auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); + cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); + + //cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream); + //cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream); + + LaunchContext 
lc(stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); + z.tickWriteDevice(); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestAdd_2) { + // allocating host-side arrays + NDArray x('c', { 5 }, { 1, 2, 3, 4, 5}); + NDArray y('c', { 5 }, { 1, 2, 3, 4, 5}); + NDArray z('c', { 5 }, sd::DataType::DOUBLE); + + NDArray exp('c', { 5 }, { 2, 4, 6, 8, 10 }); + + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); + cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); + + LaunchContext lc(stream, *stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestAdd_3) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', { 5 }, {10, 10, 10, 10, 10}); + + auto exp = NDArrayFactory::create('c', { 5 }, { 
2, 4, 6, 8, 10 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + + Nd4jPointer nativeStream = (Nd4jPointer)malloc(sizeof(cudaStream_t)); + CHECK_ALLOC(nativeStream, "Failed to allocate memory for new CUDA stream", sizeof(cudaStream_t)); + cudaError_t dZ = cudaStreamCreate(reinterpret_cast(&nativeStream)); + auto stream = reinterpret_cast(&nativeStream); + + //cudaMemcpyAsync(devBufferPtrX, x.buffer(), x.lengthOf() * x.sizeOfT(), cudaMemcpyHostToDevice, *stream); + //cudaMemcpyAsync(devShapePtrX, x.shapeInfo(), shape::shapeInfoByteLength(x.shapeInfo()), cudaMemcpyHostToDevice, *stream); + + LaunchContext lc(stream, *stream, nullptr, nullptr); + NativeOpExecutioner::execPairwiseTransform(&lc, pairwise::Add, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); + z.tickWriteDevice(); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + //double* localBuffer = ; + z.syncToHost(); + cudaMemcpy(z.buffer(), z.specialBuffer(), z.lengthOf() * z.sizeOfT(), cudaMemcpyDeviceToHost); + res = cudaStreamSynchronize(*stream); + z.tickWriteHost(); + ASSERT_EQ(0, res); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestAdd_4) { + // allocating host-side arrays + auto x = 
NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + x.applyPairwiseTransform(pairwise::Add, y, z); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestAdd_5) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + //auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', { 5 }, { 2, 4, 6, 8, 10 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + x += y; + //x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.syncToHost(); + //y.printBuffer("3Y = "); + //z.printBuffer("3Result out"); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + for 
(int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestAdd_6) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create(2); //.'c', { 5 }, { 1, 2, 3, 4, 5}); + //auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', { 5 }, { 3, 4, 5, 6, 7 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + x += y; + //x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.syncToHost(); + + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestAdd_7) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + //auto y = NDArrayFactory::create(2); //.'c', { 5 }, { 1, 2, 3, 4, 5}); + //auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', { 5 }, { 3, 4, 5, 6, 7 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), 
shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + x += 2.; + //x.applyPairwiseTransform(pairwise::Add, &y, &z, nullptr); + x.syncToHost(); + + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestMultiply_1) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + x.applyPairwiseTransform(pairwise::Multiply, y, z); + // x.printBuffer("3X = "); + // y.printBuffer("3Y = "); + // z.printBuffer("3Result out"); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestMultiply_2) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + NDArray z('c', { 5 }, sd::DataType::DOUBLE); + + auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, 
devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + x.applyPairwiseTransform(pairwise::Multiply, y, z); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestMultiply_3) { + // allocating host-side arrays + NDArray x('c', { 5 }, { 1, 2, 3, 4, 5}, sd::DataType::DOUBLE); + NDArray y('c', { 5 }, { 1., 2., 3., 4., 5.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + x.applyPairwiseTransform(pairwise::Multiply, y, z); + //x.printBuffer("23X = "); + //y.printBuffer("23Y = "); + // z.printBuffer("23Result out"); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + for (int e = 0; e < z.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + } +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestMultiply_4) { + // allocating host-side arrays + NDArray x('c', { 5 }, { 1, 2, 3, 4, 5}, 
sd::DataType::DOUBLE); + NDArray y('c', { 5 }, { 1., 2., 3., 4., 5.}, sd::DataType::DOUBLE); + //auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', { 5 }, { 1, 4, 9, 16, 25 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + //x.printBuffer("23X = "); + //y.printBuffer("23Y = "); + x *= y; + //x.tickWriteDevice(); + // x.printBuffer("33Result out"); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + for (int e = 0; e < x.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + } +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestPrimitiveNeg_01) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create('c', { 5 }, { -1, -2, -3, -4, -5 }); + + auto stream = x.getContext()->getCudaStream();//reinterpret_cast(&nativeStream); + + NativeOpExecutioner::execTransformSame(x.getContext(), transform::Neg, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), nullptr, nullptr, nullptr); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + y.tickWriteDevice(); + + // x.printBuffer("X = "); + // y.printBuffer("Y = "); + + for (int e = 0; e < y.lengthOf(); e++) { + ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); + } +} + +TEST_F(NDArrayCudaBasicsTests, 
Test_PrimitiveNeg_2) { + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + x.applyTransform(transform::Neg, y); + //ASSERT_TRUE(x->isActualOnDeviceSide()); + //ASSERT_FALSE(x->isActualOnHostSide()); + + //ASSERT_TRUE(y->isActualOnDeviceSide()); + //ASSERT_TRUE(y->isActualOnHostSide()); + //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + //ASSERT_EQ(0, res); + // y.printBuffer("Negatives2"); + //delete x; + //delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveSqrt_1) { // strict + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create({1.000000, 1.414214, 1.732051, 2.000000, 2.236068}); + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + x.applyTransform(transform::Sqrt, y); + //ASSERT_TRUE(x->isActualOnDeviceSide()); + //ASSERT_FALSE(x->isActualOnHostSide()); + + //ASSERT_TRUE(y->isActualOnDeviceSide()); + //ASSERT_TRUE(y->isActualOnHostSide()); + //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + //ASSERT_EQ(0, res); + ASSERT_TRUE(y.equalsTo(exp)); + //y.printBuffer("SQRT output"); + //delete x; + //delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveAssign_1) { // strict + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + //auto exp = NDArrayFactory::create({1.000000, 1.414214, 1.732051, 2.000000, 2.236068}); + //ASSERT_TRUE(x.isActualOnDeviceSide()); + //ASSERT_TRUE(x.isActualOnHostSide()); + + x.applyTransform(transform::Assign, y); + //ASSERT_TRUE(x->isActualOnDeviceSide()); + //ASSERT_FALSE(x->isActualOnHostSide()); + + //ASSERT_TRUE(y->isActualOnDeviceSide()); + //ASSERT_TRUE(y->isActualOnHostSide()); + //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + 
//ASSERT_EQ(0, res); + + // printf("Assigned to another array\n"); + // y.printBuffer("OUput"); + ASSERT_TRUE(y.equalsTo(x)); + //y.syncToHost(); + //y.printBuffer("IsMax output"); + //delete x; + //delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_1) { // strict + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create('c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + + x.applyTransform(transform::Cosine, y); + //ASSERT_TRUE(x->isActualOnDeviceSide()); + //ASSERT_FALSE(x->isActualOnHostSide()); + + //ASSERT_TRUE(y->isActualOnDeviceSide()); + //ASSERT_TRUE(y->isActualOnHostSide()); + //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + //ASSERT_EQ(0, res); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(exp.dataType() == y.dataType()); + //y.printBuffer("Cosine2"); + //delete x; + //delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_2) { + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create('c', {5}, {0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + x.applyTransform(transform::Cosine, y); + //ASSERT_TRUE(x->isActualOnDeviceSide()); + //ASSERT_FALSE(x->isActualOnHostSide()); + + //ASSERT_TRUE(y->isActualOnDeviceSide()); + //ASSERT_TRUE(y->isActualOnHostSide()); + //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + //ASSERT_EQ(0, res); + //exp.syncToHost(); + //y.printBuffer("PrimitiveCosine2"); + //exp.printBuffer("Primitive Cosine exp"); + ASSERT_TRUE(exp.isSameShape(y)); + ASSERT_TRUE(exp.dataType() == y.dataType()); + //for (int e = 0; e < y.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); + //} + + 
ASSERT_TRUE(exp.equalsTo(y)); + //delete x; + //delete y; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PrimitiveCosine_3) { + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {5}); + auto exp = NDArrayFactory::create({0.540302, -0.416147, -0.989992, -0.653644, 0.283662}); + + ASSERT_TRUE(x.isActualOnDeviceSide()); + ASSERT_FALSE(x.isActualOnHostSide()); + x.applyTransform(transform::Cosine, y); + //ASSERT_TRUE(x->isActualOnDeviceSide()); + //ASSERT_FALSE(x->isActualOnHostSide()); + + //ASSERT_TRUE(y->isActualOnDeviceSide()); + //ASSERT_TRUE(y->isActualOnHostSide()); + //auto res = cudaStreamSynchronize(*y.getContext()->getCudaStream()); + //ASSERT_EQ(0, res); + //exp.syncToHost(); +// y.printBuffer("PrimitiveCosine3"); +// exp.printBuffer("Primitive Cosine3 exp"); +// y.printShapeInfo("Y shape"); +// exp.printShapeInfo("Exp Shape"); + ASSERT_TRUE(exp.isSameShape(y)); +// +// for (int e = 0; e < y.lengthOf(); e++) { +// printf("%lf == %lf\n", exp.e(e), y.e(e)); +//// ASSERT_NEAR(exp.e(e), y.e(e), 1e-5); +// } + + ASSERT_TRUE(exp.equalsTo(y)); + //delete x; + //delete y; +} + +TEST_F(NDArrayCudaBasicsTests, TestRawBroadcast_2) { + + //if (!Environment::getInstance().isExperimentalBuild()) + // return; + + NDArray x = NDArrayFactory::create('c', {2,3,4}); + NDArray y('c', {2,4}, {10,20,30,40,50,60,70,80}, sd::DataType::DOUBLE); + NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); +// NDArray exp('c', {2,3,4}, {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::DOUBLE); + NDArray exp('c', {2,3,4}, {10., 40., 90., 160., 50., 120., 210., 320., 90., 200., 330., 480., 650., 840., 1050., 1280., 850., 1080., 1330., 1600., 1050., 1320., 1610., 1920.}, sd::DataType::DOUBLE); + x.linspace(1); x.syncToDevice(); + + std::vector dimensions = {0,2}; + + // 
evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t stream; + cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext lc(&stream); + + // allocate required amount of global device memory and copy host data to it + cudaResult = allocateDeviceMem(lc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast(&lc, sd::broadcast::Multiply, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + z.tickWriteDevice(); + + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +TEST_F(NDArrayCudaBasicsTests, TestRawBroadcast_3) { + + //if (!Environment::getInstance().isExperimentalBuild()) + // return; + + NDArray x('c', {2,3,4}, sd::DataType::DOUBLE); + 
NDArray y('c', {2,4}, {10,20,30,40,50,60,70,80}, sd::DataType::DOUBLE); + NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::DOUBLE); +// NDArray exp('c', {2,3,4}, {10., 21., 32., 43., 14., 25., 36., 47., 18., 29., 40., 51., 62., 73., 84., 95., 66., 77., 88., 99., 70., 81., 92., 103}, sd::DataType::DOUBLE); + NDArray exp('c', {2,3,4}, {10., 40., 90., 160., 50., 120., 210., 320., 90., 200., 330., 480., 650., 840., 1050., 1280., 850., 1080., 1330., 1600., 1050., 1320., 1610., 1920.}, sd::DataType::DOUBLE); + x.linspace(1); x.syncToDevice(); + + std::vector dimensions = {0,2}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(int)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + //cudaStream_t stream; + //cudaResult = cudaStreamCreate(&stream); ASSERT_EQ(0, cudaResult); + LaunchContext* pLc = x.getContext();//(&stream); + cudaStream_t* stream = pLc->getCudaStream(); + // allocate required amount of global device memory and copy host data to it +// cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + for(int i = 0; i < devicePtrs.size(); ++i) { + + cudaResult = cudaMalloc(reinterpret_cast(&devicePtrs[i]), hostData[i].second); ASSERT_EQ(0, cudaResult); + cudaMemcpyAsync(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice, *stream); + } + + 
NDArray::registerSpecialUse({&z}, {&x, &y}); + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast(pLc, sd::broadcast::Multiply, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + nullptr, nullptr); + + //cudaResult = cudaStreamSynchronize(stream); ASSERT_EQ(0, cudaResult); + //z.syncToHost(); + // verify results + for (int e = 0; e < z.lengthOf(); e++) + ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + ASSERT_TRUE(exp.equalsTo(z)); + // delete cuda stream + //cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + + +TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_1) { + // allocating host-side arrays + NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create(3.); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + //auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', { 2, 3 }, { 3, 6, 9, 12, 15, 18 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + x *= y; + //x.syncToHost(); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(x)); +// for (int e = 0; e < x.lengthOf(); 
e++) { +// ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); +// } +} + +TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_01) { + // allocating host-side arrays + NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y = NDArrayFactory::create(3.); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', { 2, 3 }); + + auto exp = NDArrayFactory::create('c', { 2, 3 }, { 3, 6, 9, 12, 15, 18 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + //x.printBuffer("23X = "); + //y.printBuffer("23Y = "); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);// *= y; + // z.printBuffer("53Result out"); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(z)); + +// for (int e = 0; e < x.lengthOf(); e++) { +// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); +// } +} + +TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_02) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}); //, sd::DataType::DOUBLE); + auto y = NDArrayFactory::create('c', {2,3}, {3, 3, 3, 3, 3, 3}); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', { 2, 3 }); + + auto exp = NDArrayFactory::create('c', { 2, 3 }, { 3, 6, 9, 12, 15, 18 }); + //if (x.isActualOnHostSide() && !x.isActualOnDeviceSide()) + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * 
x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + //x.printBuffer("23X = "); + //y.printBuffer("23Y = "); + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, z);// *= y; + + // z.printBuffer("52Result out"); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + ASSERT_TRUE(exp.equalsTo(z)); + +// for (int e = 0; e < x.lengthOf(); e++) { +// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); +// } +} + +TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_002) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}); //, sd::DataType::DOUBLE); + auto y = NDArrayFactory::create('c', {2, 3}, {2., 3., 3., 3., 3., 3.}); //'c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + auto z = NDArrayFactory::create('c', { 2, 3 }); + + auto exp = NDArrayFactory::create('c', { 2, 3 }, { 2, 6, 9, 12, 15, 18 }); + //if (x.isActualOnHostSide() && !x.isActualOnDeviceSide()) + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + //x.printBuffer("23X = "); + //y.printBuffer("23Y = "); + x.applyPairwiseTransform(pairwise::Multiply, y, z);// *= y; + + // z.printBuffer("51Result out"); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + 
ASSERT_TRUE(exp.equalsTo(z)); + +// for (int e = 0; e < x.lengthOf(); e++) { +// ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); +// } +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestBroadcastRaw_1) { + + //if (!Environment::getInstance().isExperimentalBuild()) + // return; + + NDArray x('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); + NDArray y('c', {3}, {10, 20, 30}, sd::DataType::INT64); + NDArray z('c', {2,3,4}, {100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100,100}, sd::DataType::INT32); + NDArray exp('c', {2,3,4}, {10, 11, 12, 13,24, 25, 26, 27,38, 39, 40, 41,22, 23, 24, 25,36, 37, 38, 39,50, 51, 52, 53}, sd::DataType::INT32); + //real output [10, 11, 12, 13, 4, 5, 6, 7, 28, 29, 30, 31, 22, 23, 24, 25, 16, 17, 18, 19, 40, 41, 42, 43] + x.linspace(0); x.syncToDevice(); + + std::vector dimensions = {1}; + + // evaluate xTad data + shape::TAD xTad; + xTad.init(x.shapeInfo(), dimensions.data(), dimensions.size()); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); + + // prepare input arrays for prepareDataForCuda function + std::vector> hostData; + hostData.emplace_back(dimensions.data(), dimensions.size() * sizeof(Nd4jLong)); // 0 -- dimensions + hostData.emplace_back(xTad.tadOnlyShapeInfo, shape::shapeInfoByteLength(xTad.tadOnlyShapeInfo)); // 1 -- xTadShapeInfo + hostData.emplace_back(xTad.tadOffsets, xTad.numTads * sizeof(Nd4jLong)); // 2 -- xTadOffsets + std::vector devicePtrs(hostData.size(), nullptr); + + // create cuda stream and LaunchContext + cudaError_t cudaResult; + cudaStream_t* stream = x.getContext()->getCudaStream(); + LaunchContext* pLc = x.getContext(); + + // allocate required amount of global device memory and copy host data to it + //cudaResult = allocateDeviceMem(*pLc, devicePtrs, hostData); ASSERT_EQ(0, cudaResult); + for(size_t i = 0; i < 
devicePtrs.size(); ++i) { + cudaResult = cudaMalloc(&devicePtrs[i], hostData[i].second); //if(cudaResult != 0) return cudaResult; + ASSERT_EQ(cudaResult, 0); + cudaMemcpy(devicePtrs[i], hostData[i].first, hostData[i].second, cudaMemcpyHostToDevice); + } + + // call cuda kernel which calculates result + NativeOpExecutioner::execBroadcast(pLc, sd::broadcast::Add, + nullptr, x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), + nullptr, y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo(), + nullptr, z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), + (int*)devicePtrs[0], dimensions.size(), + (Nd4jLong*)devicePtrs[1], (Nd4jLong*)devicePtrs[2], + nullptr, nullptr); + + cudaResult = cudaStreamSynchronize(*stream); ASSERT_EQ(0, cudaResult); + + // x.printIndexedBuffer(" X"); + // y.printIndexedBuffer("+Y"); + // z.printBuffer("ADD broadcasted output"); + // verify results + // for (int e = 0; e < z.lengthOf(); e++) + // ASSERT_NEAR(exp.e(e), z.e(e), 1e-5); + + // free allocated global device memory + for(int i = 0; i < devicePtrs.size(); ++i) + cudaFree(devicePtrs[i]); + + // delete cuda stream + //cudaResult = cudaStreamDestroy(stream); ASSERT_EQ(0, cudaResult); +} + +TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply) { + // allocating host-side arrays + NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y('c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + //auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', { 2, 3 }, { 2, 6, 12, 8, 15, 24 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + 
//x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + //x.printBuffer("23X = "); + //y.printBuffer("23Y = "); + x *= y; + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + //for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + //} +} + + +TEST_F(NDArrayCudaBasicsTests, TestBroadcastMultiply_2) { + // allocating host-side arrays + NDArray x('c', { 2, 3 }, { 1, 2, 3, 4, 5, 6}, sd::DataType::DOUBLE); + NDArray y('c', { 3 }, { 2., 3., 4.}, sd::DataType::DOUBLE); + //auto z = NDArrayFactory::create('c', { 5 }); + + auto exp = NDArrayFactory::create('c', { 2, 3 }, { 11,12, 13,14, 15, 16 }); + auto expZ = NDArrayFactory::create('c', { 2, 3 }, { 2, 6, 12, 8, 15, 24 }); + + // making raw buffers + //Nd4jPointer devBufferPtrX, devBufferPtrZ, devShapePtrX; + //cudaError_t res = cudaMalloc(reinterpret_cast(&devBufferPtrX), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devBufferPtrZ), x.lengthOf() * x.sizeOfT()); + //ASSERT_EQ(0, res); + //res = cudaMalloc(reinterpret_cast(&devShapePtrX), shape::shapeInfoByteLength(x.shapeInfo())); + //ASSERT_EQ(0, res); + //x.applyPairwiseTransform(pairwise::Multiply, &y, &z, nullptr); + //x.printBuffer("23X = "); + //y.printBuffer("23Y = "); + //void NDArray::applyTrueBroadcast(sd::BroadcastOpsTuple op, const NDArray* other, NDArray* target, const bool checkTargetShape, ExtraArguments *extraArgs) + x.applyTrueBroadcast(BroadcastOpsTuple::Multiply(), y, exp); + + // + // cudaFree(devBufferPtrX); + //cudaFree(devBufferPtrZ); + //cudaFree(devShapePtrX); + + //for (int e = 0; e < x.lengthOf(); e++) { + // ASSERT_NEAR(exp.e(e), x.e(e), 1e-5); + //} + ASSERT_TRUE(exp.equalsTo(expZ)); + +} + + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestReduceSum_1) { + // allocating host-side arrays + auto x = NDArrayFactory::create('c', { 5 }, { 1, 2, 3, 4, 5}); + 
auto y = NDArrayFactory::create(15); + auto exp = NDArrayFactory::create(15); + + auto stream = x.getContext()->getCudaStream();//reinterpret_cast(&nativeStream); + + NativeOpExecutioner::execReduceSameScalar(x.getContext(), reduce::Sum, x.buffer(), x.shapeInfo(), x.specialBuffer(), x.specialShapeInfo(), nullptr, y.buffer(), y.shapeInfo(), y.specialBuffer(), y.specialShapeInfo()); + auto res = cudaStreamSynchronize(*stream); + ASSERT_EQ(0, res); + y.syncToHost(); + + ASSERT_NEAR(y.e(0), 15, 1e-5); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestDup1) { + + NDArray array('c', {2,3}, {1,2,3,4,5,6}); + auto arrC = array.dup('c'); + auto arrF = array.dup('f'); + // arrC->printBuffer("arrC"); + + // arrF->printBuffer("arrF"); + //arrC->printShapeInfo("C shape"); + //arrF->printShapeInfo("F shape"); + + ASSERT_TRUE(array.equalsTo(arrF)); + ASSERT_TRUE(array.equalsTo(arrC)); + + ASSERT_TRUE(arrF.equalsTo(arrC)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, equalsTo_1) { + + NDArray x('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); + NDArray y('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); + + ASSERT_TRUE(x.equalsTo(y)); + + x.permutei({1,0}); + y.permutei({1,0}); + + ASSERT_TRUE(x.equalsTo(y)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, equalsTo_2) { + + NDArray x('c', {2,5}, {1,2,3,4,5,6,7,8,10,10}, sd::DataType::DOUBLE); + NDArray y('c', {2,5}, {1,2,5,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); + + ASSERT_FALSE(x.equalsTo(y)); + + x.permutei({1,0}); + y.permutei({1,0}); + + ASSERT_FALSE(x.equalsTo(y)); +} + +////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, equalsTo_3) { + + NDArray x('c', {2,5}, {1,2,3,4,5,6,7,8,9,10}, sd::DataType::DOUBLE); + NDArray y('c', {2,5}, 
{1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f}, sd::DataType::FLOAT32); + + ASSERT_FALSE(x.equalsTo(y)); + + x.permutei({1,0}); + y.permutei({1,0}); + + ASSERT_FALSE(x.equalsTo(y)); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, applyReduce3_1) { + + NDArray x('c', {2,3,4}, {-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::INT32); + NDArray x2('c', {2,3,4}, {-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::INT32); + NDArray y('c', {2,3,4}, {-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2,3,-4,5}, sd::DataType::INT32); + NDArray k('c', {2,3}, {-2,3,-4,5,-2,3}, sd::DataType::INT32); + NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3}, sd::DataType::INT32); + + NDArray exp1('c', {3}, {4.f, 20.f, 36.f}, sd::DataType::FLOAT32); + NDArray exp2('c', {2,3}, {-10.f, -2.f, 6.f,14.f, 22.f, 30.f}, sd::DataType::FLOAT32); + NDArray exp3('c', {4}, {38.f, 41.f, 44.f, 47.f}, sd::DataType::FLOAT32); + NDArray exp4('c', {4}, {114.f, 117.f, 120.f, 123.f}, sd::DataType::FLOAT32); + + + NDArray z = x.applyReduce3(sd::reduce3::Dot, y, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x.applyReduce3(sd::reduce3::Dot, k, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp3)); + + x.permutei({0,2,1}); + y.permutei({0,2,1}); + + z = y.applyReduce3(sd::reduce3::Dot, x, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); + + x2.permutei({1,0,2}); + + z = x2.applyReduce3(sd::reduce3::Dot, k2, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp4)); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, applyReduce3_2) { + + NDArray x('c', {2,3,4}, {-10,-9,-8.5,-7,-6,-5,-4,-3,-2,-1,0,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::DOUBLE); + NDArray x2('c', {2,3,4}, {-10,-9,-8,-7,-6,-5,-4,-3,-2,-1,0.5,1,2,3,4,5,6,7,8,9,10,11,12,13}, sd::DataType::DOUBLE); + NDArray y('c', {2,3,4}, {-2,3,-4,5,-2,3,-4,5,-2,3,-4,5,-2.5,3,-4,5,-2,3,-4,5,-2,3,-4,5}, 
sd::DataType::DOUBLE); + NDArray k('c', {2,3}, {-2,3,-4,5.5,-2,3}, sd::DataType::DOUBLE); + NDArray k2('c', {3,2}, {-2,3,-4,5,-2,3.5}, sd::DataType::DOUBLE); + + NDArray exp1('c', {3}, {5., 20., 36.}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,3}, {-8., -2., 6., 13., 22., 30.}, sd::DataType::DOUBLE); + NDArray exp3('c', {4}, {39., 42.5, 47., 49.5}, sd::DataType::DOUBLE); + NDArray exp4('c', {4}, {119., 122.5, 125., 129.5}, sd::DataType::DOUBLE); + + NDArray z = x.applyReduce3(sd::reduce3::Dot, y, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x.applyReduce3(sd::reduce3::Dot, k, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp3)); + + x.permutei({0,2,1}); + y.permutei({0,2,1}); + + z = y.applyReduce3(sd::reduce3::Dot, x, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); + + x2.permutei({1,0,2}); + + z = x2.applyReduce3(sd::reduce3::Dot, k2, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp4)); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, applyReduce3_3) { + + NDArray x1('c', {2,2,2}, {1,2,3,4,5,6,7,8}, sd::DataType::INT32); + NDArray x2('c', {2,2,2}, {-1,-2,-3,-4,-5,-6,-7,-8}, sd::DataType::INT32); + NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {3,2}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); + + NDArray exp1('c', {}, std::vector{-204}, sd::DataType::FLOAT32); + NDArray exp2('c', {}, std::vector{31.5}, sd::DataType::DOUBLE); + + + auto z = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_TRUE(z.equalsTo(&exp2)); + + x1.permutei({2,1,0}); + x2.permutei({2,1,0}); + x3.permutei({1,0}); + x4.permutei({1,0}); + + z = x1.applyReduce3(reduce3::Dot, x2); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x3.applyReduce3(reduce3::Dot, x4); + ASSERT_TRUE(z.equalsTo(&exp2)); +} + +//////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, applyAllReduce3_1) { + + NDArray x1('c', 
{2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, sd::DataType::INT32); + NDArray x2('c', {2,2,2}, {-1,-2,-3,-4,-5,-6,-7,-8}, sd::DataType::INT32); + NDArray x3('c', {3,2}, {1.5,1.5,1.5,1.5,1.5,1.5}, sd::DataType::DOUBLE); + NDArray x4('c', {3,2}, {1,2,3,4,5,6}, sd::DataType::DOUBLE); + + NDArray exp1('c', {3,2}, {-88.f, -124.f, 6.f, -2.f, 22.f, 14.f}, sd::DataType::FLOAT32); + NDArray exp2('c', {6,4}, {-36.f, -44.f, -52.f, -60.f,-42.f, -52.f, -62.f, -72.f, 2.f, 0.f, -2.f, + -4.f, 6.f, 4.f, 2.f, 0.f, 10.f, 8.f, 6.f, 4.f, 14.f, 12.f, 10.f, 8.f}, + sd::DataType::FLOAT32); + NDArray exp3('c', {1,1}, std::vector{31.5}, sd::DataType::DOUBLE); + NDArray exp4('c', {3,3}, {4.5, 10.5, 16.5,4.5, 10.5, 16.5,4.5, 10.5, 16.5}, sd::DataType::DOUBLE); + + auto z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x1.applyAllReduce3(reduce3::Dot, x2, {0}); + ASSERT_TRUE(z.equalsTo(&exp2)); + + z = x3.applyAllReduce3(reduce3::Dot, x4, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp3)); + + z = x3.applyAllReduce3(reduce3::Dot, x4, {1}); + ASSERT_TRUE(z.equalsTo(&exp4)); + + x1.permutei({2,1,0}); + x2.permutei({2,1,0}); + x3.permutei({1,0}); + x4.permutei({1,0}); + + z = x1.applyAllReduce3(reduce3::Dot, x2, {0,2}); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x3.applyAllReduce3(reduce3::Dot, x4, {0}); + ASSERT_TRUE(z.equalsTo(&exp4)); +} + +////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test1) { + + NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, sd::DataType::DOUBLE); + + NDArray scalar('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray vec1('c', {2}, {100,100}, sd::DataType::INT64); + NDArray vec2('c', {3}, {100,100,100}, sd::DataType::INT64); + + NDArray exp1('c', {}, std::vector{1}, sd::DataType::INT64); + NDArray exp2('c', {2}, {1,1}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1,0,0}, sd::DataType::INT64); + + NDArray exp4('c', {}, std::vector{2}, sd::DataType::INT64); + 
NDArray exp5('c', {2}, {1,1}, sd::DataType::INT64); + NDArray exp6('c', {3}, {1,0,0}, sd::DataType::INT64); + + x.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0,1}); + ASSERT_TRUE(scalar.equalsTo(&exp1)); + + x.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {1}); + ASSERT_TRUE(vec1.equalsTo(&exp2)); + + x.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {0}); + ASSERT_TRUE(vec2.equalsTo(&exp3)); + + x.permutei({1,0}); + + x.applyIndexReduce(sd::indexreduce::IndexMax, scalar, {0,1}); + ASSERT_TRUE(scalar.equalsTo(&exp4)); + + x.applyIndexReduce(sd::indexreduce::IndexMax, vec1, {0}); + ASSERT_TRUE(vec1.equalsTo(&exp5)); + + x.applyIndexReduce(sd::indexreduce::IndexMax, vec2, {1}); + ASSERT_TRUE(vec2.equalsTo(&exp6)); +} + + +////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, applyIndexReduce_test2) { + + NDArray x('c', {2,3}, {0, 10, 1, 2, 2.5,-4}, sd::DataType::DOUBLE); + + NDArray exp1('c', {}, std::vector{1}, sd::DataType::INT64); + NDArray exp2('c', {2}, {1,1}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1,0,0}, sd::DataType::INT64); + + NDArray exp4('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp5('c', {2}, {1,1}, sd::DataType::INT64); + NDArray exp6('c', {3}, {1,0,0}, sd::DataType::INT64); + + auto z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp1)); + + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}); + ASSERT_TRUE(z.equalsTo(&exp2)); + + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0}); + ASSERT_TRUE(z.equalsTo(&exp3)); + + x.permutei({1,0}); + + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0,1}); + ASSERT_TRUE(z.equalsTo(&exp4)); + + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {0}); + ASSERT_TRUE(z.equalsTo(&exp5)); + + z = x.applyIndexReduce(sd::indexreduce::IndexMax, {1}); + ASSERT_TRUE(z.equalsTo(&exp6)); +} + +//////////////////////////////////////////////////////////////////////////////// 
+TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test1) { + + NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, sd::DataType::INT32); + + NDArray z1('c', {}, std::vector{100}, sd::DataType::DOUBLE); + NDArray z2('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); + NDArray z3('c', {3}, {100,100,100}, sd::DataType::DOUBLE); + NDArray z4('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); + NDArray z5('c', {2}, {100,100}, sd::DataType::FLOAT32); + + NDArray exp1('c', {}, std::vector{2.166667}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {3.f,4.f,1.f,0.666667f}, sd::DataType::FLOAT32); + NDArray exp3('c', {3}, {4.5,1,1}, sd::DataType::DOUBLE); + NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, sd::DataType::FLOAT32); + NDArray exp5('c', {2}, {3.5f,0.833333f}, sd::DataType::FLOAT32); + + x.reduceAlongDimension(sd::reduce::Mean, z1, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::Mean, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + x.reduceAlongDimension(sd::reduce::Mean, z3, {0,2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1,0,2}); // 3x2x2 + + x.reduceAlongDimension(sd::reduce::Mean, z1, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::Mean, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); + + x.reduceAlongDimension(sd::reduce::Mean, z5, {0,2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_float_test2) { + + NDArray x('c', {2,3,2}, {1,2,3,4,5,6,7,8,-1,-2,-3,-4,}, sd::DataType::DOUBLE); + + NDArray exp1('c', {}, std::vector{2.166667}, sd::DataType::DOUBLE); + NDArray exp2('c', {2,2}, {3,4,1,0.666667}, sd::DataType::DOUBLE); + NDArray exp3('c', {3}, {4.5,1,1}, sd::DataType::DOUBLE); + NDArray exp4('c', {3,2}, {4,5,1,1,1,1}, sd::DataType::DOUBLE); + NDArray exp5('c', {2}, {3.5,0.833333}, sd::DataType::DOUBLE); + + NDArray z1 = 
x.reduceAlongDimension(sd::reduce::Mean, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + NDArray z2 = x.reduceAlongDimension(sd::reduce::Mean, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + NDArray z3 = x.reduceAlongDimension(sd::reduce::Mean, {0,2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1,0,2}); // 3x2x2 + + NDArray z4 = x.reduceAlongDimension(sd::reduce::Mean, {0,1,2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); + + NDArray z5 = x.reduceAlongDimension(sd::reduce::Mean, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); + + NDArray z6 = x.reduceAlongDimension(sd::reduce::Mean, {0,2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, EqualityTest1) { + auto arrayA = NDArrayFactory::create_('f', {3, 5}); + auto arrayB = NDArrayFactory::create_('f', {3, 5}); + auto arrayC = NDArrayFactory::create_('f', {3, 5}); + + auto arrayD = NDArrayFactory::create_('f', {2, 4}); + auto arrayE = NDArrayFactory::create_('f', {1, 15}); + + for (int i = 0; i < arrayA->rows(); i++) { + for (int k = 0; k < arrayA->columns(); k++) { + arrayA->p(i, k, (float) i); + } + } + + for (int i = 0; i < arrayB->rows(); i++) { + for (int k = 0; k < arrayB->columns(); k++) { + arrayB->p(i, k, (float) i); + } + } + + for (int i = 0; i < arrayC->rows(); i++) { + for (int k = 0; k < arrayC->columns(); k++) { + arrayC->p(i, k, (float) i+1); + } + } + + ASSERT_TRUE(arrayA->equalsTo(arrayB, 1e-5)); + + ASSERT_FALSE(arrayC->equalsTo(arrayB, 1e-5)); + + ASSERT_FALSE(arrayD->equalsTo(arrayB, 1e-5)); + + ASSERT_FALSE(arrayE->equalsTo(arrayB, 1e-5)); + + delete arrayA; + delete arrayB; + delete arrayC; + delete arrayD; + delete arrayE; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test1) { + + NDArray x('c', {2,3,2}, {1.5f,2.f,3.f,4.f,5.f,6.f,7.5f,8.f,-1.f,-2.f,-3.5f,-4.f}, sd::DataType::FLOAT32); + + NDArray z1('c', {}, 
std::vector{100}, sd::DataType::FLOAT32); + NDArray z2('c', {2,2}, {100,100,100,100}, sd::DataType::FLOAT32); + NDArray z3('c', {3}, {100,100,100}, sd::DataType::FLOAT32); + NDArray z4('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::FLOAT32); + NDArray z5('c', {2}, {100,100}, sd::DataType::FLOAT32); + + NDArray exp1('c', {}, std::vector{26.5f}, sd::DataType::FLOAT32); + NDArray exp2('c', {2,2}, {9.5f,12.f,3.f,2.f}, sd::DataType::FLOAT32); + NDArray exp3('c', {3}, {19.f,4.f,3.5f}, sd::DataType::FLOAT32); + NDArray exp4('c', {3,2}, {9.f,10.f,2.f,2.f,1.5f,2.f}, sd::DataType::FLOAT32); + NDArray exp5('c', {2}, {21.5f,5.f}, sd::DataType::FLOAT32); + + x.reduceAlongDimension(sd::reduce::Sum, z1, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::Sum, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + x.reduceAlongDimension(sd::reduce::Sum, z3, {0,2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1,0,2}); // 3x2x2 + + x.reduceAlongDimension(sd::reduce::Sum, z1, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::Sum, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); + + x.reduceAlongDimension(sd::reduce::Sum, z5, {0,2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_same_test2) { + + NDArray x('c', {2,3,2}, {1.5,2,3,4,5,6,7.5,8,-1,-2,-3.5,-4,}, sd::DataType::INT64); + + NDArray exp1('c', {}, std::vector{26}, sd::DataType::INT64); + NDArray exp2('c', {2,2}, {9,12,3,2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {18,4,4}, sd::DataType::INT64); + NDArray exp4('c', {3,2}, {8,10,2,2,2,2}, sd::DataType::INT64); + NDArray exp5('c', {2}, {21,5}, sd::DataType::INT64); + + NDArray z1 = x.reduceAlongDimension(sd::reduce::Sum, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + NDArray z2 = x.reduceAlongDimension(sd::reduce::Sum, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + NDArray z3 = 
x.reduceAlongDimension(sd::reduce::Sum, {0,2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1,0,2}); // 3x2x2 + + NDArray z4 = x.reduceAlongDimension(sd::reduce::Sum, {0,1,2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); + + NDArray z5 = x.reduceAlongDimension(sd::reduce::Sum, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); + + NDArray z6 = x.reduceAlongDimension(sd::reduce::Sum, {0,2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test1) { + + NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, sd::DataType::DOUBLE); + + NDArray z1('c', {}, std::vector{true}, sd::DataType::BOOL); + NDArray z2('c', {2,2}, {true,true,true,true}, sd::DataType::BOOL); + NDArray z3('c', {3}, {true,true,true}, sd::DataType::BOOL); + NDArray z4('c', {3,2}, {true,true,true,true,true,true}, sd::DataType::BOOL); + NDArray z5('c', {2}, {true,true}, sd::DataType::BOOL); + + NDArray exp1('c', {}, std::vector{true}, sd::DataType::BOOL); + NDArray exp2('c', {2,2}, {true,true,false,true}, sd::DataType::BOOL); + NDArray exp3('c', {3}, {true,true,true}, sd::DataType::BOOL); + NDArray exp4('c', {3,2}, {true,true,true,false,true,true}, sd::DataType::BOOL); + NDArray exp5('c', {2}, {true,true}, sd::DataType::BOOL); + + x.reduceAlongDimension(sd::reduce::IsPositive, z1, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::IsPositive, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + x.reduceAlongDimension(sd::reduce::IsPositive, z3, {0,2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1,0,2}); // 3x2x2 + + x.reduceAlongDimension(sd::reduce::IsPositive, z1, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::IsPositive, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); + + x.reduceAlongDimension(sd::reduce::IsPositive, z5, {0,2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); +} + 
+//////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_bool_test2) { + + NDArray x('c', {2,3,2}, {0.5,2,3,-4,5,6,-7.5,8,-1,-0.5,-3.5,4}, sd::DataType::INT32); + + NDArray exp1('c', {}, std::vector{1}, sd::DataType::BOOL); + NDArray exp2('c', {2,2}, {1,1,0,1}, sd::DataType::BOOL); + NDArray exp3('c', {3}, {1,1,1}, sd::DataType::BOOL); + NDArray exp4('c', {3,2}, {0,1,1,0,1,1}, sd::DataType::BOOL); + NDArray exp5('c', {2}, {1,1}, sd::DataType::BOOL); + + NDArray z1 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + NDArray z2 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + NDArray z3 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1,0,2}); // 3x2x2 + + NDArray z4 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,1,2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); + + NDArray z5 = x.reduceAlongDimension(sd::reduce::IsPositive, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); + + NDArray z6 = x.reduceAlongDimension(sd::reduce::IsPositive, {0,2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test1) { + + NDArray x('c', {2,3,2}, {0.5f,2.f,3.f,-0.f,5.f,6.f,-7.5f,0.f,-1.f,-0.5f,-3.5f,4.f}, sd::DataType::FLOAT32); + + NDArray z1('c', {}, std::vector{100}, sd::DataType::INT64); + NDArray z2('c', {2,2}, {100,100,100,100}, sd::DataType::INT64); + NDArray z3('c', {3}, {100,100,100}, sd::DataType::INT64); + NDArray z4('c', {3,2}, {100,100,100,100,100,100}, sd::DataType::INT64); + NDArray z5('c', {2}, {100,100}, sd::DataType::INT64); + + NDArray exp1('c', {}, std::vector{2}, sd::DataType::INT64); + NDArray exp2('c', {2,2}, {0,1,0,1}, sd::DataType::INT64); + NDArray exp3('c', {3}, {1,1,0}, sd::DataType::INT64); + NDArray exp4('c', 
{3,2}, {0,1,0,1,0,0}, sd::DataType::INT64); + NDArray exp5('c', {2}, {1,1}, sd::DataType::INT64); + + x.reduceAlongDimension(sd::reduce::CountZero, z1, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::CountZero, z2, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + x.reduceAlongDimension(sd::reduce::CountZero, z3, {0,2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1,0,2}); // 3x2x2 + + x.reduceAlongDimension(sd::reduce::CountZero, z1, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + x.reduceAlongDimension(sd::reduce::CountZero, z4, {1}); + ASSERT_TRUE(z4.equalsTo(&exp4)); + + x.reduceAlongDimension(sd::reduce::CountZero, z5, {0,2}); + ASSERT_TRUE(z5.equalsTo(&exp5)); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, reduceAlongDimension_long_test2) { + + NDArray x('c', {2,3,2}, {0.5,2,3,-0,5,6,-7.5,0,-1,-0.5,-3.5,4}, sd::DataType::INT32); + + NDArray exp1('c', {}, std::vector{4}, sd::DataType::INT64); + NDArray exp2('c', {2,2}, {1,1,0,2}, sd::DataType::INT64); + NDArray exp3('c', {3}, {2,2,0}, sd::DataType::INT64); + NDArray exp4('c', {3,2}, {1,1,0,2,0,0}, sd::DataType::INT64); + NDArray exp5('c', {2}, {2,2}, sd::DataType::INT64); + + NDArray z1 = x.reduceAlongDimension(sd::reduce::CountZero, {0,1,2}); + ASSERT_TRUE(z1.equalsTo(&exp1)); + + NDArray z2 = x.reduceAlongDimension(sd::reduce::CountZero, {1}); + ASSERT_TRUE(z2.equalsTo(&exp2)); + + NDArray z3 = x.reduceAlongDimension(sd::reduce::CountZero, {0,2}); + ASSERT_TRUE(z3.equalsTo(&exp3)); + + x.permutei({1,0,2}); // 3x2x2 + + NDArray z4 = x.reduceAlongDimension(sd::reduce::CountZero, {0,1,2}); + ASSERT_TRUE(z4.equalsTo(&exp1)); + + NDArray z5 = x.reduceAlongDimension(sd::reduce::CountZero, {1}); + ASSERT_TRUE(z5.equalsTo(&exp4)); + + NDArray z6 = x.reduceAlongDimension(sd::reduce::CountZero, {0,2}); + ASSERT_TRUE(z6.equalsTo(&exp5)); +} + +TEST_F(NDArrayCudaBasicsTests, BroadcastOpsTest1) { + + auto x = 
NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); + NDArray expRow('c', {1, 5,}, {1,2,3,4,5}, sd::DataType::FLOAT32); + NDArray exp('c', {5,5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + + ASSERT_TRUE(row->equalsTo(&expRow)); + + x.applyBroadcast(broadcast::Add, {1}, *row, z); + x += *row; + + ASSERT_TRUE(x.equalsTo(z)); + //ASSERT_TRUE(z.equalsTo(&exp)); + + delete row; +} + +TEST_F(NDArrayCudaBasicsTests, BroadcastOpsTest2) { + + auto x = NDArrayFactory::create('c', {5, 5}); + //auto z = NDArrayFactory::create('c', {5, 5}); + auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); + NDArray expRow('c', {1, 5,}, {1,2,3,4,5}, sd::DataType::FLOAT32); + NDArray exp('c', {5,5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, sd::DataType::FLOAT32); + + ASSERT_TRUE(row->equalsTo(&expRow)); + x.applyBroadcast(broadcast::Add, {1}, *row, x); + ASSERT_TRUE(x.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, TestBroadcast_1) { + + NDArray exp('c', {2, 3, 2, 2}, {1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3.}, sd::DataType::DOUBLE); + + auto input = NDArrayFactory::create('c',{ 2, 3, 2, 2}); + auto bias = NDArrayFactory::create('c', {1, 3}); + + bias.linspace(1); + input.applyBroadcast(broadcast::Add, {1}, bias, input); + ASSERT_TRUE(exp.equalsTo(&input)); +} + +TEST_F(NDArrayCudaBasicsTests, TestFloat16_1) { + auto x = NDArrayFactory::create({1,2,3,4,5,7,8,9}); + auto y = NDArrayFactory::create({1,2,3,4,5,7,8,9}); + ASSERT_TRUE(x.equalsTo(&y)); +} + +TEST_F(NDArrayCudaBasicsTests, TestFloat16_2) { + auto x = NDArrayFactory::create('c', {9}, {1,2,3,4,5,6,7,8,9}); + auto y = NDArrayFactory::create('c', {9}, {1,2,3,4,5,6,7,8,9}); + ASSERT_TRUE(x.equalsTo(y)); + //for (int e = 0; e < 
x.lengthOf(); e++) + // ASSERT_NEAR(x.e(e), y.e(e), 1.e-5f); +} + +TEST_F(NDArrayCudaBasicsTests, TestFloat16_3) { + auto x = NDArrayFactory::create({1,2,3,4,5,7,8,9}); + auto y = NDArrayFactory::create({1,2,3,4,5,7,8,9}); + ASSERT_TRUE(x.equalsTo(&y)); +} + +TEST_F(NDArrayCudaBasicsTests, TestFloat_4) { + auto x = NDArrayFactory::create({1,2,3,4,5,7,8,9}); + auto y = NDArrayFactory::create({2,4,5,5,6,7,8,9}); + ASSERT_FALSE(x.equalsTo(&y)); +} + +TEST_F(NDArrayCudaBasicsTests, TestFloat_5) { + auto x = NDArrayFactory::create('c', {3,3}, {1,2,3,4,5,6,7,8,9}); + auto y = NDArrayFactory::create('c', {3,3}, {2,4,5,5,6,7,8,9, 10}); + ASSERT_FALSE(x.equalsTo(&y)); +} + +TEST_F(NDArrayCudaBasicsTests, TestFloat_6) { + auto x = NDArrayFactory::create('f', {3,3}, {1,2,3,4,5,6,7,8,9}); + auto y = NDArrayFactory::create('f', {3,3}, {2,4,5,5,6,7,8,9,10}); + ASSERT_FALSE(x.equalsTo(&y)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_05) +{ + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {1, 8, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); + NDArray res2 = NDArrayFactory::create(expected.ordering(), expected.getShapeAsVector()); + x = 1.; + y = 2.; + expected = 3.; + res2 = 0.f; + + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, res2);// *= y; + + ASSERT_TRUE(expected.isSameShape(&res2)); + ASSERT_TRUE(expected.equalsTo(&res2)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_5) +{ + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {8, 1, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); + NDArray res2(expected); + x = 1.; + y = 2.; + expected = 3.; + //x.printBuffer("X="); + //y.printBuffer("Y="); + //expected.printBuffer("EXPECTED"); + auto result = x + y; + //result.printBuffer("1 + 2 ="); + 
//res2.assign(x + y); + + //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2); + //res2.printBuffer("Z="); + //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; +// x += y; + //x.printBuffer("OutputX"); + //res2.syncToHost(); + //res2.printBuffer("OUputZ"); + //x.printIndexedBuffer("OUtputX"); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_51) +{ + auto x = NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {8, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); + NDArray res2(expected); + x = 1.; + y = 2.; + expected = 3.; + //x.printBuffer("X="); + //y.printBuffer("Y="); + //expected.printBuffer("EXPECTED"); + auto result = x + y; + //result.printBuffer("1 + 2 ="); + //res2.assign(x + y); + + //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2); + //res2.printBuffer("Z="); + //x.applyTrueBroadcast(BroadcastOpsTuple::Add(), &y, &res2);// *= y; +// x += y; + //x.printBuffer("OutputX"); + //res2.syncToHost(); + //res2.printBuffer("OUputZ"); + //x.printIndexedBuffer("OUtputX"); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_1) +{ + auto x = NDArrayFactory::create('c', {2, 1, 2}); + x = 10.; + auto y = x.tile({1,2,1}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}); + exp = 10.; + + // y.printShapeInfo("Output SHAPE"); + // y.printBuffer("Output TILE"); + // exp.printBuffer("Expect TILE"); + ASSERT_TRUE(exp.equalsTo(y)); +} + +TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_2) +{ + auto x = NDArrayFactory::create('f', {2, 1, 2}); + x = 10.; + auto y = x.tile({1,2,1}); + auto exp = NDArrayFactory::create('f', {2, 2, 2}); + exp = 10.; + ASSERT_TRUE(exp.equalsTo(y)); +} + +TEST_F(NDArrayCudaBasicsTests, Tile_Test_2_3) +{ + auto x = 
NDArrayFactory::create('f', {2, 1, 2}); + x = 10.; + x.p(1,0,1, 20); + x.syncToDevice(); + auto y = x.tile({1,2,1}); + auto exp = NDArrayFactory::create('f', {2, 2, 2}); + exp = 10.; + exp.p(1,0,1, 20.); + exp.p(1, 1, 1, 20.); + exp.syncToDevice(); + ASSERT_TRUE(exp.equalsTo(y)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, Operator_Plus_Test_2) +{ + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; + NDArray a('c', {4,4}, {1,2,3,4,5,6,7,8,9,2,3,2,1,0,4,7}, sd::DataType::FLOAT32); + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + + x.linspace(1); + y.linspace(1); + auto result = x + y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, assign_2) +{ + NDArray x('c', {4}, {1.5f,2.5f,3.5f,4.5f}, sd::DataType::FLOAT32); + NDArray y('c', {4}, sd::DataType::INT32); + NDArray expected('c', {4}, {1,2,3,4}, sd::DataType::INT32); + + y.assign(x); + // y.printBuffer("ASSIGN VECTOR"); + + ASSERT_TRUE(expected.equalsTo(&y)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, subarray_1) +{ + NDArray x('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); + NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); + + Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 1, 99}; + float buffExpX0[] = {1.f, 13.f}; + Nd4jLong shapeExpX1[] = {1, 2, 12, 8192, 1, 99}; + float buffExpX1[] = {2.f, 14.f}; + Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 1, 99}; + float buffExpX2[] = {1.f, 13.f}; + Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 1, 99}; + float buffExpX3[] = {9.f, 10.f, 11.f, 
12.f, 21.f, 22.f, 23.f, 24.f}; + Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 1, 99}; + float buffExpX4[] = {9.f, 10.f, 11.f, 12.f, 21.f, 22.f, 23.f, 24.f}; + Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 1, 99}; + float buffExpX5[] = {4.f, 8.f, 12.f, 16.f, 20.f, 24.f}; + + Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 99}; + float buffExpY0[] = {1.f, 2.f}; + Nd4jLong shapeExpY1[] = {1, 2, 1, 8192, 1, 99}; + float buffExpY1[] = {7.f, 8.f}; + Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; + float buffExpY2[] = {1.f, 2.f}; + Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 1, 99}; + float buffExpY3[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; + Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 1, 102}; + float buffExpY4[] = {5.f, 11.f, 17.f, 23.f, 6.f, 12.f, 18.f, 24.f}; + Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 99}; + float buffExpY5[] = {19.f, 21.f, 23.f, 20.f, 22.f, 24.f}; + + + NDArray x0 = x(0, {1,2}); + NDArray xExp(buffExpX0, shapeExpX0); + + ASSERT_TRUE(xExp.isSameShape(x0)); + ASSERT_TRUE(xExp.equalsTo(x0)); +// for(int i = 0; i < shape::shapeInfoLength(x0.rankOf()); ++i) +// ASSERT_TRUE(x0.shapeInfo()[i] == shapeExpX0[i]); +// for(int i = 0; i < x0.lengthOf(); ++i) +// ASSERT_TRUE(x0.e(i) == buffExpX0[i]); + + NDArray x1 = x(1, {1,2}); + NDArray x1Exp(buffExpX1, shapeExpX1); + ASSERT_TRUE(x1Exp.isSameShape(x1)); + ASSERT_TRUE(x1Exp.equalsTo(x1)); + +// for(int i = 0; i < shape::shapeInfoLength(x1.rankOf()); ++i) +// ASSERT_TRUE(x1.shapeInfo()[i] == shapeExpX1[i]); +// for(int i = 0; i < x1.lengthOf(); ++i) +// ASSERT_TRUE(x1.e(i) == buffExpX1[i]); + + NDArray x2 = x(0, {1,2}, true); + NDArray x2Exp(buffExpX2, shapeExpX2); + ASSERT_TRUE(x2Exp.isSameShape(x2)); +// x2.printBuffer("X2"); +// x2Exp.printBuffer("X2 EXPECT"); + ASSERT_TRUE(x2Exp.equalsTo(x2)); +// for(int i = 0; i < shape::shapeInfoLength(x2.rankOf()); ++i) +// ASSERT_TRUE(x2.shapeInfo()[i] == shapeExpX2[i]); +// for(int i = 0; i < x2.lengthOf(); ++i) +// 
ASSERT_TRUE(x2.e(i) == buffExpX2[i]); + + NDArray x3 = x(2, {1}); + NDArray x3Exp(buffExpX3, shapeExpX3); + ASSERT_TRUE(x3Exp.isSameShape(x3)); + ASSERT_TRUE(x3Exp.equalsTo(x3)); +// for(int i = 0; i < shape::shapeInfoLength(x3.rankOf()); ++i) +// ASSERT_TRUE(x3.shapeInfo()[i] == shapeExpX3[i]); +// for(int i = 0; i < x3.lengthOf(); ++i) +// ASSERT_TRUE(x3.e(i) == buffExpX3[i]); + + NDArray x4 = x(2, {1}, true); + NDArray x4Exp(buffExpX4, shapeExpX4); + ASSERT_TRUE(x4Exp.isSameShape(x4)); + ASSERT_TRUE(x4Exp.equalsTo(x4)); +// for(int i = 0; i < shape::shapeInfoLength(x4.rankOf()); ++i) +// ASSERT_TRUE(x4.shapeInfo()[i] == shapeExpX4[i]); +// for(int i = 0; i < x4.lengthOf(); ++i) +// ASSERT_TRUE(x4.e(i) == buffExpX4[i]); + + NDArray x5 = x(3, {2}); + NDArray x5Exp(buffExpX5, shapeExpX5); + ASSERT_TRUE(x5Exp.isSameShape(x5)); + ASSERT_TRUE(x5Exp.equalsTo(x5)); + +// for(int i = 0; i < shape::shapeInfoLength(x5.rankOf()); ++i) +// ASSERT_TRUE(x5.shapeInfo()[i] == shapeExpX5[i]); +// for(int i = 0; i < x5.lengthOf(); ++i) +// ASSERT_TRUE(x5.e(i) == buffExpX5[i]); + + // ******************* // + NDArray y0 = y(0, {1,2}); + NDArray y0Exp(buffExpY0, shapeExpY0); + ASSERT_TRUE(y0Exp.isSameShape(y0)); + ASSERT_TRUE(y0Exp.equalsTo(y0)); +// for(int i = 0; i < shape::shapeInfoLength(y0.rankOf()); ++i) +// ASSERT_TRUE(y0.shapeInfo()[i] == shapeExpY0[i]); +// for(int i = 0; i < y0.lengthOf(); ++i) +// ASSERT_TRUE(y0.e(i) == buffExpY0[i]); + + NDArray y1 = y(1, {1,2}); + NDArray y1Exp(buffExpY1, shapeExpY1); + ASSERT_TRUE(y1Exp.isSameShape(y1)); + ASSERT_TRUE(y1Exp.equalsTo(y1)); +// for(int i = 0; i < shape::shapeInfoLength(y1.rankOf()); ++i) +// ASSERT_TRUE(y1.shapeInfo()[i] == shapeExpY1[i]); +// for(int i = 0; i < y1.lengthOf(); ++i) +// ASSERT_TRUE(y1.e(i) == buffExpY1[i]); + + NDArray y2 = y(0, {1,2}, true); + NDArray y2Exp(buffExpY2, shapeExpY2); + ASSERT_TRUE(y2Exp.isSameShape(y2)); + ASSERT_TRUE(y2Exp.equalsTo(y2)); +// for(int i = 0; i < 
shape::shapeInfoLength(y2.rankOf()); ++i) +// ASSERT_TRUE(y2.shapeInfo()[i] == shapeExpY2[i]); +// for(int i = 0; i < y2.lengthOf(); ++i) +// ASSERT_TRUE(y2.e(i) == buffExpY2[i]); + + NDArray y3 = y(2, {1}); + NDArray y3Exp(buffExpY3, shapeExpY3); + ASSERT_TRUE(y3Exp.isSameShape(y3)); + ASSERT_TRUE(y3Exp.equalsTo(y3)); +// for(int i = 0; i < shape::shapeInfoLength(y3.rankOf()); ++i) +// ASSERT_TRUE(y3.shapeInfo()[i] == shapeExpY3[i]); +// for(int i = 0; i < y3.lengthOf(); ++i) +// ASSERT_TRUE(y3.e(i) == buffExpY3[i]); + + NDArray y4 = y(2, {1}, true); + NDArray y4Exp = NDArrayFactory::create('f', {2,1,4}, {5, 6, 11, 12, 17, 18, 23, 24}); + ASSERT_TRUE(y4Exp.isSameShape(y4)); + ASSERT_TRUE(y4Exp.equalsTo(y4)); +// for(int i = 0; i < shape::shapeInfoLength(y4.rankOf()); ++i) +// ASSERT_TRUE(y4.shapeInfo()[i] == shapeExpY4[i]); +// for(int i = 0; i < y4.lengthOf(); ++i) +// ASSERT_TRUE(y4.e(i) == buffExpY4[i]); + + NDArray y5 = y(3, {2}); + NDArray y5Exp(buffExpY5, shapeExpY5); + ASSERT_TRUE(y5Exp.isSameShape(y5)); + ASSERT_TRUE(y5Exp.equalsTo(y5)); +// for(int i = 0; i < shape::shapeInfoLength(y5.rankOf()); ++i) +// ASSERT_TRUE(y5.shapeInfo()[i] == shapeExpY5[i]); +// for(int i = 0; i < y5.lengthOf(); ++i) +// ASSERT_TRUE(y5.e(i) == buffExpY5[i]); + +} +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayCudaBasicsTests, Test_diagonal_1) { + + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); + + auto diag = x.diagonal('c'); + //diag.syncToDevice(); + for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { + printf("VAL[%ld] = %f\n", e, diag.e(e)); //, exp.e(e), 1.e-5); + } + + for (Nd4jLong e = 0; e < exp.lengthOf(); ++e) { + ASSERT_NEAR(diag.e(e), exp.e(e), 1.e-5); + } + double eps(1.e-5); + NDArray tmp(sd::DataType::FLOAT32, x.getContext()); // scalar = 0 + + ExtraArguments extras({eps}); + NativeOpExecutioner::execReduce3Scalar(diag.getContext(), 
reduce3::EqualsWithEps, diag.buffer(), + diag.shapeInfo(), diag.specialBuffer(), diag.specialShapeInfo(), extras.argumentsAsT(sd::DataType::FLOAT32), + exp.buffer(), exp.shapeInfo(), exp.specialBuffer(), exp.specialShapeInfo(), + tmp.buffer(), tmp.shapeInfo(), tmp.specialBuffer(), tmp.specialShapeInfo()); + cudaStream_t* stream = x.getContext()->getCudaStream(); + auto res = cudaStreamSynchronize(*stream); + // tmp.printBuffer("Compare result is (expected 0)"); + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_02) { + auto x = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); + //x.linspace(1); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x->reshapei('c', {3, 4, 5}); + + x->permutei({0, 1, 2}); + x->streamline(); + +// x.printShapeInfo("{0, 1, 2} shape"); +// x.printBuffer("{0, 1, 2} data"); + + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); + delete x; +} + +TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_0) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + 
x.reshapei('c', {3, 4, 5}); + + x.permutei({0, 1, 2}); + x.streamline(); + +// x.printShapeInfo("{0, 1, 2} shape"); +// x.printBuffer("{0, 1, 2} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} +TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_1) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({0, 1, 2}); + x.streamline(); + +// x.printShapeInfo("{0, 1, 2} shape"); +// x.printBuffer("{0, 1, 2} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} +TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_2) { + //auto x = NDArrayFactory::create('c', {1, 60}); + auto xx = NDArrayFactory::linspace(1.f, 60.f, 60); //('c', {1, 60}); +// auto x = *xx; + //x.linspace(1); +// auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); +// x.reshapei('c', {3, 4, 5}); + +// x.permutei({0, 1, 2}); +// x.streamline(); + +// x.printShapeInfo("{0, 1, 2} shape"); +// x.printBuffer("{0, 1, 2} data"); + +// ASSERT_TRUE(exp.isSameShape(&x)); +// ASSERT_TRUE(exp.equalsTo(&x)); + delete xx; +} 
+TEST_F(NDArrayCudaBasicsTests, Test_PermuteEquality_3) { + auto x = NDArrayFactory::create('c', {1, 60}); + //x.linspace(1); + for (int l = 0; l < x.lengthOf(); l++) + x.p(l, float(l + 1.f)); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f, 33.0f, 34.0f, 35.0f, 36.0f, 37.0f, 38.0f, 39.0f, 40.0f, 41.0f, 42.0f, 43.0f, 44.0f, 45.0f, 46.0f, 47.0f, 48.0f, 49.0f, 50.0f, 51.0f, 52.0f, 53.0f, 54.0f, 55.0f, 56.0f, 57.0f, 58.0f, 59.0f, 60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({0, 1, 2}); + x.streamline(); + +// x.printShapeInfo("{0, 1, 2} shape"); +// x.printBuffer("{0, 1, 2} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(NDArrayCudaBasicsTests, Test_Empty_1) { + auto x = NDArrayFactory::empty(); + ASSERT_TRUE(x.isActualOnHostSide()); + ASSERT_TRUE(x.isEmpty()); +} + +TEST_F(NDArrayCudaBasicsTests, Test_Empty_2) { + auto x = NDArrayFactory::empty_(); + + ASSERT_TRUE(x->isEmpty()); + delete x; +} + +TEST_F(NDArrayCudaBasicsTests, Test_Empty_3) { + auto x = NDArrayFactory::empty(sd::DataType::FLOAT32); + + ASSERT_TRUE(x.isEmpty()); +} + +TEST_F(NDArrayCudaBasicsTests, Test_Empty_4) { + auto x = NDArrayFactory::empty_(sd::DataType::FLOAT32); + + ASSERT_TRUE(x->isEmpty()); + delete x; +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayListTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayListTests.cpp new file mode 100644 index 000000000..14ffedb71 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayListTests.cpp @@ -0,0 +1,75 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made 
available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include "testlayers.h" + +using namespace sd; + +class NDArrayListTests : public testing::Test { +public: + +}; + + +TEST_F(NDArrayListTests, BasicTests_1) { + NDArrayList list(false); + + auto x = NDArrayFactory::create('c', {1, 10}); + auto y = NDArrayFactory::create('c', {1, 10}); + + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup()))); + + //ASSERT_EQ(ND4J_STATUS_DOUBLE_WRITE, list.write(1, &y)); +} + +TEST_F(NDArrayListTests, BasicTests_2) { + NDArrayList list(false); + + auto x = NDArrayFactory::create('c', {1, 10}); + auto y = NDArrayFactory::create('c', {1, 7}); + + ASSERT_EQ(ND4J_STATUS_OK, list.write(1, new NDArray(x.dup()))); + + ASSERT_EQ(ND4J_STATUS_BAD_INPUT, list.write(0, &y)); +} + + +TEST_F(NDArrayListTests, Test_Stack_UnStack_1) { + auto input = NDArrayFactory::create('c', {10, 10}); + input.linspace(1); + + NDArrayList list(false); + + list.unstack(&input, 0); + + ASSERT_EQ(10, list.elements()); + + auto array = list.stack(); + + ASSERT_TRUE(input.isSameShape(array)); + + ASSERT_TRUE(input.equalsTo(array)); + + delete array; +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayTests.cpp 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayTests.cpp new file mode 100644 index 000000000..46d8e2311 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayTests.cpp @@ -0,0 +1,2682 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 04.08.17. 
+// + +#include "testlayers.h" +#include +#include +#include + +using namespace sd; + +////////////////////////////////////////////////////////////////////// +class NDArrayTest : public testing::Test { +public: + int alpha = 0; + + Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong *fShape = new Nd4jLong[8]{2, 2, 2, 1, 2, 8192, 1, 102}; + + float arr1[6] = {1,2,3,4,5,6}; + Nd4jLong shape1[8] = {2,2,3,3,1,8192,1,99}; + float arr2[48] = {1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6}; + Nd4jLong shape2[10] = {3,2,4,6,24,6,1,8192,1,99}; + const std::vector tileShape1 = {2,2,2}; + + + ~NDArrayTest() { + delete[] cShape; + delete[] fShape; + } +}; + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestDup1) { + + NDArray array(arr1, shape1); + + auto arrC = new NDArray(array.dup('c')); + auto arrF = new NDArray(array.dup('f')); + + ASSERT_TRUE(array.equalsTo(arrF)); + ASSERT_TRUE(array.equalsTo(arrC)); + + ASSERT_TRUE(arrF->equalsTo(arrC)); + + delete arrC; + delete arrF; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, AssignScalar1) { + auto array = NDArrayFactory::create_('c', {1, 10}); + + array->assign(2.0f); + + for (int i = 0; i < array->lengthOf(); i++) { + ASSERT_EQ(2.0f, array->e(i)); + } + + delete array; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, NDArrayOrder1) { + // original part + auto c = new float[4] {1, 2, 3, 4}; + + // expected part + auto f = new float[4] {1, 3, 2, 4}; + + auto arrayC = new NDArray(c, cShape); + auto arrayF = new NDArray(arrayC->dup('f')); + auto arrayC2 = new NDArray(arrayF->dup('c')); + + arrayF->syncToHost(); + arrayC2->syncToHost(); + + ASSERT_EQ('c', arrayC->ordering()); + ASSERT_EQ('f', arrayF->ordering()); + ASSERT_EQ('c', arrayC2->ordering()); + + for (int i = 0; i < 4; i++) { + ASSERT_NEAR(f[i], 
arrayF->bufferAsT()[i], 1e-5f); + } + + for (int i = 0; i < 8; i++) { + ASSERT_EQ(fShape[i], arrayF->shapeInfo()[i]); + } + + for (int i = 0; i < 4; i++) { + ASSERT_NEAR(c[i], arrayC2->bufferAsT()[i], 1e-5f); + } + + for (int i = 0; i < 8; i++) { + ASSERT_EQ(cShape[i], arrayC2->shapeInfo()[i]); + } + + + delete[] c; + delete[] f; + delete arrayC; + delete arrayF; + delete arrayC2; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestGetScalar1) { + auto c = new float[4] {1, 2, 3, 4}; + auto cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; + + auto arrayC = new NDArray(c, cShape); + + ASSERT_NEAR(3.0f, arrayC->e(1, 0), 1e-5f); + ASSERT_NEAR(4.0f, arrayC->e(1, 1), 1e-5f); + + auto arrayF = new NDArray(arrayC->dup('f')); + + ASSERT_NEAR(3.0f, arrayF->e(1, 0), 1e-5f); + ASSERT_NEAR(4.0f, arrayF->e(1, 1), 1e-5f); + + + arrayF->p(1, 0, 7.0f); + ASSERT_NEAR(7.0f, arrayF->e(1, 0), 1e-5f); + + + arrayC->p(1, 1, 9.0f); + ASSERT_NEAR(9.0f, arrayC->e(1, 1), 1e-5f); + + delete[] c; + delete[] cShape; + + delete arrayC; + delete arrayF; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, EqualityTest1) { + auto arrayA = NDArrayFactory::create_('f', {3, 5}); + auto arrayB = NDArrayFactory::create_('f', {3, 5}); + auto arrayC = NDArrayFactory::create_('f', {3, 5}); + + auto arrayD = NDArrayFactory::create_('f', {2, 4}); + auto arrayE = NDArrayFactory::create_('f', {1, 15}); + + for (int i = 0; i < arrayA->rows(); i++) { + for (int k = 0; k < arrayA->columns(); k++) { + arrayA->p(i, k, (float) i); + } + } + + for (int i = 0; i < arrayB->rows(); i++) { + for (int k = 0; k < arrayB->columns(); k++) { + arrayB->p(i, k, (float) i); + } + } + + for (int i = 0; i < arrayC->rows(); i++) { + for (int k = 0; k < arrayC->columns(); k++) { + arrayC->p(i, k, (float) i+1); + } + } + + //nd4j_printf("A B\n",""); + ASSERT_TRUE(arrayA->equalsTo(arrayB, 1e-5)); + + //nd4j_printf("C B\n",""); + 
ASSERT_FALSE(arrayC->equalsTo(arrayB, 1e-5)); + + //nd4j_printf("D B\n",""); + ASSERT_FALSE(arrayD->equalsTo(arrayB, 1e-5)); + + //nd4j_printf("E B\n",""); + ASSERT_FALSE(arrayE->equalsTo(arrayB, 1e-5)); + + delete arrayA; + delete arrayB; + delete arrayC; + delete arrayD; + delete arrayE; +} + +TEST_F(NDArrayTest, TestTad1) { + auto array = NDArrayFactory::create_('c', {3, 3}); + + auto row2 = (*array)(1, {0}); + + ASSERT_TRUE(row2.isView()); + ASSERT_EQ(3, row2.lengthOf()); + + row2.assign(1.0); + + ASSERT_NEAR(3.0f, array->sumNumber().e(0), 1e-5); + delete array; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTad2) { + auto array = NDArrayFactory::create_('c', {3, 3}); + + ASSERT_EQ(3, array->tensorsAlongDimension({1})); + + delete array; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTad3) { + auto array = NDArrayFactory::create_('c', {4, 3}); + + auto row2 = (*array)(1, {0}); + + ASSERT_TRUE(row2.isView()); + ASSERT_EQ(3, row2.lengthOf()); + delete array; +} + + +TEST_F(NDArrayTest, TestPermuteReshape1) { + + NDArray array('c', {2, 2, 5, 5}, sd::DataType::FLOAT32); + int pShape[] = {4, 2, 5, 5, 2, 25, 5, 1, 50, 8192, 0, 99}; + int rShape[] = {3, 2, 25, 2, 25, 1, 50, 8192, 0, 99}; + + array.permutei({1, 2, 3, 0}); + + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(pShape[e], array.shapeInfo()[e]); + + array.reshapei('c', {2, 25, 2}); + + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(rShape[e], array.shapeInfo()[e]); +} + + +TEST_F(NDArrayTest, TestPermuteReshape2) { + auto array = NDArrayFactory::create('c', {2, 2, 5, 5, 6, 6}); + int pShape[] = {6, 2, 2, 6, 6, 5, 5, 900, 1800, 6, 1, 180, 36, 8192, 0, 99}; + int rShape[] = {3, 2, 72, 25, 1800, 25, 1, 8192, 1, 99}; + + + // array.printShapeInfo("before"); + + array.permutei({1, 0, 4, 5, 2, 3}); + + // array.printShapeInfo("after "); + + 
auto aShape = array.shapeInfo(); + + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(pShape[e], aShape[e]); + + array.reshapei('c', {2, 72, 25}); + + for (int e = 0; e < shape::shapeInfoLength(array.shapeInfo()); e++) + ASSERT_EQ(rShape[e], array.shapeInfo()[e]); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestRepeat1) { + + auto eBuffer = new float[8] {1.0,2.0,1.0,2.0,3.0,4.0,3.0,4.0}; + auto eShape = new Nd4jLong[8]{2, 4, 2, 2, 1, 8192, 1, 99}; + NDArray array('c', {2, 2}, sd::DataType::FLOAT32); + auto exp = new NDArray(eBuffer, eShape); + for (int e = 0; e < array.lengthOf(); e++) + array.p(e, e + 1); + + // array.printBuffer(); + + auto rep = array.repeat(0, {2}); + + ASSERT_EQ(4, rep.sizeAt(0)); + ASSERT_EQ(2, rep.sizeAt(1)); + + ASSERT_TRUE(exp->equalsTo(rep)); + + delete[] eBuffer; + delete[] eShape; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestRepeat2) { + auto eBuffer = new float[8] {1.0,2.0,1.0,2.0,3.0,4.0,3.0,4.0}; + auto eShape = new Nd4jLong[8]{2, 4, 2, 2, 1, 8192, 1, 99}; + auto array = NDArrayFactory::create_('c', {2, 2}); + auto exp = new NDArray(eBuffer, eShape); + for (int e = 0; e < array->lengthOf(); e++) + array->p(e, e + 1); + + //array->printBuffer(); + + auto rep = new NDArray(exp->dup()); + rep->assign(0.); + array->repeat(0, {2}, *rep); + //rep->printIndexedBuffer("Repeated"); + + ASSERT_EQ(4, rep->sizeAt(0)); + ASSERT_EQ(2, rep->sizeAt(1)); + + //rep->printBuffer(); + + ASSERT_TRUE(exp->equalsTo(rep)); + + delete[] eBuffer; + delete[] eShape; + delete array; + delete exp; + delete rep; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestIndexedPut1) { + auto array = NDArrayFactory::create_('f', {3, 3}); + + array->p(4, 1.0f); + ASSERT_EQ(1.0f, array->e(4)); + //array->printBuffer(); + + delete array; +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestSum1) { + // Nd4jLong *cShape = new Nd4jLong[8]{2, 2, 2, 2, 1, 8192, 1, 99}; + float *c = new float[4] {1, 2, 3, 4}; + + auto array = new NDArray(c, cShape); + + ASSERT_EQ(10.0f, array->sumNumber().e(0)); + ASSERT_EQ(2.5f, array->meanNumber().e(0)); + + delete[] c; + delete array; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestAddiRowVector) { + float *c = new float[4] {1, 2, 3, 4}; + float *e = new float[4] {2, 3, 4, 5}; + + auto array = new NDArray(c, cShape); + auto row = NDArrayFactory::create_('c', {1, 2}); + auto exp = new NDArray(e, cShape); + row->assign(1.0f); + + array->addiRowVector(*row); + + ASSERT_TRUE(exp->equalsTo(array)); + + delete[] c; + delete[] e; + + delete array; + delete row; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestAddiColumnVector) { + float arr1[] = {1, 2, 3, 4}; + float arr2[] = {5, 6}; + float arr3[] = {6, 7, 9, 10}; + Nd4jLong shape1[] = {2,2,2,2,1,8192,1,99}; + Nd4jLong shape2[] = {2,2,1,1,1,8192,1,99}; + NDArray matrix(arr1, shape1); + NDArray column(arr2, shape2); + NDArray exp(arr3, shape1); + + matrix.addiColumnVector(column); + ASSERT_TRUE(exp.isSameShapeStrict(matrix)); + ASSERT_TRUE(exp.equalsTo(&matrix)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMuliColumnVector) { + float arr1[] = {1, 2, 3, 4}; + float arr2[] = {5, 6}; + float arr3[] = {5, 10, 18, 24}; + Nd4jLong shape1[] = {2,2,2,2,1,8192,1,99}; + Nd4jLong shape2[] = {2,2,1,1,1,8192,1,99}; + NDArray matrix(arr1, shape1); + NDArray column(arr2, shape2); + NDArray exp(arr3, shape1); + + matrix.muliColumnVector(column); + + ASSERT_TRUE(exp.isSameShapeStrict(matrix)); + ASSERT_TRUE(exp.equalsTo(&matrix)); +} + +////////////////////////////////////////////////////////////////////// 
+TEST_F(NDArrayTest, Test3D_1) { + auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); + auto arrayF = NDArrayFactory::create_('f', {2, 5, 10}); + + ASSERT_EQ(100, arrayC->lengthOf()); + ASSERT_EQ(100, arrayF->lengthOf()); + + ASSERT_EQ('c', arrayC->ordering()); + ASSERT_EQ('f', arrayF->ordering()); + + delete arrayC; + delete arrayF; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTranspose1) { + auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); + + auto expC = new Nd4jLong[10] {3, 2, 5, 10, 50, 10, 1, 16384, 1, 99}; + auto expT = new Nd4jLong[10] {3, 10, 5, 2, 1, 10, 50, 16384, 1, 102}; + + auto arrayT = arrayC->transpose(); + + for (int e = 0; e < arrayC->rankOf(); e++) { + ASSERT_EQ(shape::shapeOf(expC)[e], arrayC->sizeAt(e)); + ASSERT_EQ(shape::shapeOf(expT)[e], arrayT.sizeAt(e)); + } + + delete arrayC; + delete[] expC; + delete[] expT; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTranspose2) { + auto arrayC = NDArrayFactory::create_('c', {2, 5, 10}); + + auto expC = new Nd4jLong[10] {3, 2, 5, 10, 50, 10, 1, 16384, 1, 99}; + auto expT = new Nd4jLong[10] {3, 10, 5, 2, 1, 10, 50, 16384, 1, 102}; + + arrayC->transposei(); + + + for (int e = 0; e < arrayC->rankOf(); e++) { + ASSERT_EQ(shape::shapeOf(expT)[e], arrayC->sizeAt(e)); + } + + delete arrayC; + delete[] expC; + delete[] expT; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestReduceAlongDimension1) { + + NDArray array('c', {2,2}, {1,2,3,4}, sd::DataType::FLOAT32); + + auto res = array.reduceAlongDimension(reduce::Sum, {0}); + + ASSERT_EQ(2, res.lengthOf()); + + ASSERT_EQ(4.0f, res.e(0)); + ASSERT_EQ(6.0f, res.e(1)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestReduceAlongDimension2) { + float *c = new float[4] {1, 2, 3, 4}; + auto array = new NDArray(c, cShape); + + auto res = 
array->reduceAlongDimension(reduce::Sum, {1}); + + ASSERT_EQ(2, res.lengthOf()); + + ASSERT_EQ(3.0f, res.e(0)); + ASSERT_EQ(7.0f, res.e(1)); + + delete[] c; + delete array; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTransform1) { + float *c = new float[4] {-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); + + float *e = new float[4] {1, 2, 3, 4}; + auto exp = new NDArray(e, cShape); + + array->applyTransform(transform::Abs, *array); + + ASSERT_TRUE(exp->equalsTo(array)); + + delete[] c; + delete array; + delete[] e; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestReduceScalar1) { + float *c = new float[4] {-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); + + ASSERT_EQ(-4, array->reduceNumber(reduce::Min, nullptr).e(0)); + + delete[] c; + delete array; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestReduceScalar2) { + float *c = new float[4] {-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); + + ASSERT_EQ(-10, array->reduceNumber(reduce::Sum, nullptr).e(0)); + + delete[] c; + delete array; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestReduceScalar3) { + auto array = new NDArray(arr1, shape1); + + ASSERT_EQ(21, array->reduceNumber(reduce::Sum, nullptr).e(0)); + + delete array; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestApplyTransform1) { + float *c = new float[4] {-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); + + float *e = new float[4] {1, 2, 3, 4}; + auto exp = new NDArray(e, cShape); + + array->applyTransform(transform::Abs, *array); + + + ASSERT_TRUE(exp->equalsTo(array)); + + delete[] c; + delete array; + + delete[] e; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestVectors1) 
{ + float *c = new float[4]{-1, -2, -3, -4}; + auto array = new NDArray(c, cShape); + + + auto vecShape = array->getShapeInfoAsVector(); + auto vecBuffer = array->getBufferAsVector(); + + ASSERT_EQ(8, vecShape.size()); + ASSERT_EQ(4, vecBuffer.size()); + + for (int e = 0; e < vecBuffer.size(); e++) { + ASSERT_NEAR(c[e], vecBuffer[e], 1e-5); + } + + for (int e = 0; e < vecShape.size(); e++) { + ASSERT_EQ(cShape[e], vecShape[e]); + } + + delete[] c; + delete array; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestChecks1) { + auto array = NDArrayFactory::create('c', {1, 5}); + + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isScalar()); + ASSERT_TRUE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_TRUE(array.isRowVector()); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestChecks2) { + auto array = NDArrayFactory::create('c', {5, 5}); + + ASSERT_TRUE(array.isMatrix()); + ASSERT_FALSE(array.isScalar()); + ASSERT_FALSE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestChecks3) { + auto array = NDArrayFactory::create('c', {5, 1}); + + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isScalar()); + ASSERT_TRUE(array.isVector()); + ASSERT_TRUE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestChecks4) { + auto array = NDArrayFactory::create('c', {1, 1}); + + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); + ASSERT_TRUE(array.isScalar()); +} + +TEST_F(NDArrayTest, TestReductionAny1) { + auto array = NDArrayFactory::create('c', {2, 2}); + array.p(0, 1.0f); + array.p(1, 1.0f); + array.p(2, 
0.0f); + array.p(3, 0.0f); + array.syncToDevice(); + auto result0 = array.reduceAlongDimension(reduce::Any, {0}); + + ASSERT_EQ(2, result0.lengthOf()); + + ASSERT_NEAR(1.0f, result0.e(0), 1e-5f); + ASSERT_NEAR(1.0f, result0.e(1), 1e-5f); + + auto result1 = array.reduceAlongDimension(reduce::Any, {1}); + + ASSERT_EQ(2, result1.lengthOf()); + + ASSERT_NEAR(1.0f, result1.e(0), 1e-5f); + ASSERT_NEAR(0.0f, result1.e(1), 1e-5f); +} + +TEST_F(NDArrayTest, TestReductionAll1) { + auto array = NDArrayFactory::create('c', {2, 2}); + array.p(0, 1.0f); + array.p(1, 1.0f); + array.p(2, 0.0f); + array.p(3, 0.0f); + + auto result0 = array.reduceAlongDimension(reduce::All, {0}); + auto result1 = array.reduceAlongDimension(reduce::All, {1}); + + ASSERT_EQ(2, result0.lengthOf()); + ASSERT_EQ(2, result1.lengthOf()); + + ASSERT_FALSE(result0.e(0)); + ASSERT_FALSE(result0.e(1)); + + ASSERT_TRUE(result1.e(0)); + ASSERT_FALSE(result1.e(1)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestChecks5) { + auto array = NDArrayFactory::create('c', {5, 5, 5}); + + ASSERT_FALSE(array.isMatrix()); + ASSERT_FALSE(array.isVector()); + ASSERT_FALSE(array.isColumnVector()); + ASSERT_FALSE(array.isRowVector()); + ASSERT_FALSE(array.isScalar()); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTile1) { + + // float arr1[6] = {1,2,3,4,5,6}; + // Nd4jLong shape1[8] = {2,2,3,3,1,8192,1,99}; + // float arr2[48] = {1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6,1,2,3,1,2,3,4,5,6,4,5,6}; + // Nd4jLong shape2[10] = {3,2,4,6,24,6,1,8192,1,99}; + + NDArray array1(arr1,shape1); // {2,3} + NDArray array2(arr2,shape2); // {2,4,6} + auto expA = new NDArray(array1.dup('c')); + + auto tiled = array1.tile(tileShape1); + + // array2.printShapeInfo("Expct shape"); + // tiled.printShapeInfo("Tiled shape"); + // tiled.printBuffer(); + + ASSERT_TRUE(tiled.isSameShape(&array2)); + 
ASSERT_TRUE(tiled.equalsTo(&array2)); + + ASSERT_TRUE(expA->isSameShape(&array1)); + ASSERT_TRUE(expA->equalsTo(&array1)); + + // delete tiled; + delete expA; +} + +TEST_F(NDArrayTest, TestTile2) { + + NDArray array1(arr1,shape1); + NDArray array2(arr2,shape2); + + auto tiled = array1.tile(tileShape1); + + ASSERT_TRUE(tiled.isSameShape(&array2)); + ASSERT_TRUE(tiled.equalsTo(&array2)); + // delete tiled; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTile3) { + + NDArray array1(arr1,shape1); + NDArray array2(arr2,shape2); + + array1.tilei(tileShape1); + + ASSERT_TRUE(array1.isSameShapeStrict(array2)); + ASSERT_TRUE(array1.equalsTo(&array2)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTile4) { + + float xBuff[] = {1,2,3,4,5,6}; + float expBuff[] = {1.f,2.f, 1.f,2.f, 3.f,4.f, 3.f,4.f, 5.f,6.f, 5.f,6.f}; + + auto x = NDArrayFactory::create(xBuff, 'c', {3,1,2}); + auto exp = NDArrayFactory::create(expBuff, 'c', {3,2,2}); + + auto result = x.tile({2,1}); + + ASSERT_TRUE(result.isSameShapeStrict(exp)); + ASSERT_TRUE(result.equalsTo(&exp)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTile5) { + + float xBuff[] = {1,2,3,4,5,6,7,8,9,10,11,12}; + float expBuff[] = {1.f, 2.f, 3.f, 4.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 5.f, 6.f, 7.f, 8.f, 9.f,10.f, 11.f,12.f, 9.f,10.f, 11.f,12.f}; + + auto x = NDArrayFactory::create(xBuff, 'c', {3,2,2}); + auto exp = NDArrayFactory::create(expBuff, 'c', {3,4,2}); + + auto result = x.tile({2,1}); + + ASSERT_TRUE(result.isSameShapeStrict(exp)); + ASSERT_TRUE(result.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTile6) +{ + double expBuff[] = {10.,11., 10.,11., 10.,11., 10.,11., 12.,13., 12.,13., 12.,13., 12.,13., 14.,15., 14.,15., 14.,15., 14.,15.}; + + auto x = NDArrayFactory::create('c', {3, 1, 2}); 
+ auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + + x.linspace(10); + + auto result = x.tile({1,4,1}); + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMmulHelper1) { + auto xBuffer = new float[3]{1.f, 2.f, 3.f}; + auto xShape = new Nd4jLong[8] {2, 1, 3, 1, 1, 8192, 1, 99}; + auto x = new NDArray(xBuffer, xShape); + + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + auto yShape = new Nd4jLong[8] {2, 1, 3, 1, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); + + auto z = MmulHelper::mmul(x, y); + + ASSERT_EQ(1, z->lengthOf()); + ASSERT_NEAR(28, z->e(0), 1e-5); + + delete z; + delete[] xBuffer; + delete[] xShape; + delete[] yBuffer; + delete[] yShape; + delete y; + delete x; +} + + +TEST_F(NDArrayTest, TestPermuteReshapeMmul1) { + auto x = NDArrayFactory::create('c', {6, 3}); + auto y = NDArrayFactory::create('c', {3, 6}); + + Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; + float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, 651.0f, 843.0f, 936.0f, 1029.0f}; + NDArray exp(_expB, _expS); + + for (int e = 0; e < x.lengthOf(); e++) + x.p(e, e+1); + + for (int e = 0; e < y.lengthOf(); e++) + y.p(e, e+1); + + x.permutei({1, 0}); + y.permutei({1, 0}); + + auto z = MmulHelper::mmul(&x, &y); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete z; +} + +TEST_F(NDArrayTest, TestPermuteReshapeMmul2) { + auto x = NDArrayFactory::create('c', {6, 3}); + auto y = NDArrayFactory::create('c', {3, 6}); + + Nd4jLong _expS[] = {2, 3, 3, 1, 3, 8192, 1, 102}; + float _expB[] = {231.0f, 252.0f, 273.0f, 537.0f, 594.0f, 651.0f, 843.0f, 936.0f, 1029.0f}; + NDArray exp(_expB, _expS); + + for (int e = 0; e < x.lengthOf(); e++) + x.p(e, e+1); + + for (int e = 0; e < y.lengthOf(); e++) + y.p(e, e+1); + + auto x_ = new NDArray(x.dup('f')); + auto y_ = new NDArray(y.dup('f')); + + 
x_->permutei({1, 0}); + y_->permutei({1, 0}); + + auto z = MmulHelper::mmul(x_, y_); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete z; + delete x_; + delete y_; +} + + +TEST_F(NDArrayTest, TestPermuteReshapeMmul3) { + auto x = NDArrayFactory::create('c', {2, 2, 2, 3, 2, 2}); + auto y = NDArrayFactory::create('c', {2, 3, 2 ,2}); + + Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; + float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, 15160.0f, 15826.0f, 16492.0f, 17158.0f}; + NDArray exp(_expB, _expS); + + for (int e = 0; e < x.lengthOf(); e++) + x.p(e, e+1); + + for (int e = 0; e < y.lengthOf(); e++) + y.p(e, e+1); + + x.permutei({0, 3, 4, 5, 1, 2}); + y.permutei({3, 2, 1, 0}); + + x.reshapei('c', {2 * 2 * 2, 3 * 2 * 2}); + y.reshapei('c', {2 * 2 * 3, 2}); + + auto z = MmulHelper::mmul(&x, &y); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete z; +} + +TEST_F(NDArrayTest, TestPermuteReshapeMmul4) { + auto x = NDArrayFactory::create('c', {2, 2, 2, 3, 2, 2}); + auto y = NDArrayFactory::create('c', {2, 3, 2 ,2}); + + Nd4jLong _expS[] = {2, 8, 2, 1, 8, 8192, 1, 102}; + float _expB[] = {1624.0f, 1858.0f, 2092.0f, 2326.0f, 5368.0f, 5602.0f, 5836.0f, 6070.0f, 4504.0f, 5170.0f, 5836.0f, 6502.0f, 15160.0f, 15826.0f, 16492.0f, 17158.0f}; + NDArray exp(_expB, _expS); + + for (int e = 0; e < x.lengthOf(); e++) + x.p(e, e+1); + + for (int e = 0; e < y.lengthOf(); e++) + y.p(e, e+1); + + auto y_ = new NDArray(y.dup('f')); + + x.permutei({0, 3, 4, 5, 1, 2}); + y_->permutei({3, 2, 1, 0}); + + x.reshapei('c', {2 * 2 * 2, 3 * 2 * 2}); + y_->reshapei('c', {2 * 2 * 3, 2}); + + auto z = MmulHelper::mmul(&x, y_); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + delete z; + delete y_; +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMmulHelper2) { + auto xBuffer = new 
float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; + Nd4jLong xShape[8] = {2, 5, 3, 3, 1, 8192, 1, 99}; + auto x = new NDArray(xBuffer, xShape, sd::LaunchContext ::defaultContext(), true); + + + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + Nd4jLong yShape[8] = {2, 3, 1, 1, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape, sd::LaunchContext ::defaultContext(), true); + + auto z = NDArrayFactory::create_('f', {5, 1}); + + auto expBuffer = new float[5]{28.00f, 64.00f, 100.00f, 136.00f, 172.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo(), sd::LaunchContext ::defaultContext(), true); + + //sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); + + MmulHelper::mmul(x, y, z); + + //z->printBuffer(); + + ASSERT_TRUE(z->equalsTo(exp)); + + delete x; + delete y; + delete z; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMmulHelper3) { + auto xBuffer = new float[15]{1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f}; + auto xShape = new Nd4jLong[8] {2, 5, 3, 1, 5, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); + + auto yBuffer = new float[3]{2.f, 4.f, 6.f}; + auto yShape = new Nd4jLong[8] {2, 3, 1, 1, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); + + auto z = NDArrayFactory::create_('f', {5, 1}); + + auto expBuffer = new float[5]{92.00f, 104.00f, 116.00f, 128.00f, 140.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); + + //sd::blas::GEMV::op('f', x->rows(), x->columns(), 1.0f, x->buffer(), y->rows(), y->buffer(), 1, 0.0, z->buffer(), 1); + + MmulHelper::mmul(x, y, z); + + //z->printBuffer(); + + ASSERT_TRUE(z->equalsTo(exp)); + + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; + + delete x; + delete y; + delete z; + delete exp; +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMmulHelper4) { + auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; + auto xShape = new Nd4jLong[8] {2, 3, 2, 2, 1, 8192, 1, 99}; + auto x = new NDArray(xBuffer, xShape); + + auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; + auto yShape = new Nd4jLong[8] {2, 2, 3, 3, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); + + auto z = NDArrayFactory::create_('f', {3, 3}); + + auto expBuffer = new float[9]{7.0f, 21.0f, 35.0f, 10.0f, 28.0f, 46.0f, 13.0f, 35.0f, 57.0f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); + + MmulHelper::mmul(x, y, z); + ASSERT_TRUE(z->equalsTo(exp)); + + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; + + delete x; + delete y; + delete z; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMmulHelper5) { + auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; + auto xShape = new Nd4jLong[8] {2, 3, 2, 1, 3, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); + + auto yBuffer = new float[6]{7, 8, 9, 0, 1, 2}; + auto yShape = new Nd4jLong[8] {2, 2, 3, 3, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); + + auto z = NDArrayFactory::create_('f', {3, 3}); + + auto expBuffer = new float[9]{7.0f, 14.0f, 21.0f, 12.0f, 21.0f, 30.0f, 17.0f, 28.0f, 39.0f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); + + MmulHelper::mmul(x, y, z); + ASSERT_TRUE(z->equalsTo(exp)); + + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; + + delete x; + delete y; + delete z; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMmulHelper6) { + auto xBuffer = new float[6]{1, 2, 3, 4, 5, 6}; + auto xShape = new Nd4jLong[8] {2, 3, 2, 1, 3, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); + + auto yBuffer = new float[6]{7, 8, 9, 0, 1, 
2}; + auto yShape = new Nd4jLong[8] {2, 2, 3, 1, 2, 8192, 1, 102}; + auto y = new NDArray(yBuffer, yShape); + + auto z = NDArrayFactory::create_('f', {3, 3}); + + auto expBuffer = new float[9]{39.0f, 54.0f, 69.0f, 9.0f, 18.0f, 27.0f, 9.0f, 12.0f, 15.0f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); + + MmulHelper::mmul(x, y, z); + ASSERT_TRUE(z->equalsTo(exp)); + + + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; + + delete x; + delete y; + delete z; + delete exp; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMmulHelper7) { + auto xBuffer = new float[15]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + auto xShape = new Nd4jLong[8] {2, 5, 3, 1, 5, 8192, 1, 102}; + auto x = new NDArray(xBuffer, xShape); + + auto yBuffer = new float[5]{2, 4, 6, 8, 10}; + auto yShape = new Nd4jLong[8] {2, 1, 5, 1, 1, 8192, 1, 99}; + auto y = new NDArray(yBuffer, yShape); + + auto z = NDArrayFactory::create_('f', {1, 3}); + + auto expBuffer = new float[9]{110.00f, 260.00f, 410.00f}; + auto exp = new NDArray(expBuffer, z->shapeInfo()); + + MmulHelper::mmul(y, x, z); + + //z->printBuffer(); + ASSERT_TRUE(z->equalsTo(exp)); + + delete[] expBuffer; + delete[] xBuffer; + delete[] yBuffer; + delete[] xShape; + delete[] yShape; + + delete x; + delete y; + delete z; + delete exp; +} + + +TEST_F(NDArrayTest, TestMmulHelper_ND_1) { + Nd4jLong _expS[] = {3, 2, 3, 3, 9, 3, 1, 8192, 1, 99}; + float _expB[] = {70.f, 80.f, 90.f, 158.f, 184.f, 210.f, 246.f, 288.f, 330.f, 1030.f, 1088.f, 1146.f, 1310.f, 1384.f, 1458.f, 1590.f, 1680.f, 1770.f}; + + auto a = NDArrayFactory::create('c', {2, 3, 4}); + for (int e = 0; e < a.lengthOf(); e++) + a.p(e, e+1); + + auto b = NDArrayFactory::create('c', {2, 4, 3}); + for (int e = 0; e < b.lengthOf(); e++) + b.p(e, e+1); + + NDArray exp(_expB, _expS); + auto c = MmulHelper::mmul(&a, &b); + + ASSERT_TRUE(exp.isSameShape(c)); + 
ASSERT_TRUE(exp.equalsTo(c)); + + delete c; +} + + +TEST_F(NDArrayTest, TestMmulHelper_ND_2) { + Nd4jLong _expS[] = {3, 2, 72, 2, 144, 2, 1, 8192, 1, 99}; + float _expB[] = { + 1.07250000e+04f, 1.10500000e+04f, 2.63500000e+04f, 2.73000000e+04f, 4.19750000e+04f, 4.35500000e+04f, + 5.76000000e+04f, 5.98000000e+04f, 7.32250000e+04f, 7.60500000e+04f, 8.88500000e+04f, 9.23000000e+04f, + 1.04475000e+05f, 1.08550000e+05f, 1.20100000e+05f, 1.24800000e+05f, 1.35725000e+05f, 1.41050000e+05f, + 1.51350000e+05f, 1.57300000e+05f, 1.66975000e+05f, 1.73550000e+05f, 1.82600000e+05f, 1.89800000e+05f, + 1.98225000e+05f, 2.06050000e+05f, 2.13850000e+05f, 2.22300000e+05f, 2.29475000e+05f, 2.38550000e+05f, + 2.45100000e+05f, 2.54800000e+05f, 2.60725000e+05f, 2.71050000e+05f, 2.76350000e+05f, 2.87300000e+05f, + 2.91975000e+05f, 3.03550000e+05f, 3.07600000e+05f, 3.19800000e+05f, 3.23225000e+05f, 3.36050000e+05f, + 3.38850000e+05f, 3.52300000e+05f, 3.54475000e+05f, 3.68550000e+05f, 3.70100000e+05f, 3.84800000e+05f, + 3.85725000e+05f, 4.01050000e+05f, 4.01350000e+05f, 4.17300000e+05f, 4.16975000e+05f, 4.33550000e+05f, + 4.32600000e+05f, 4.49800000e+05f, 4.48225000e+05f, 4.66050000e+05f, 4.63850000e+05f, 4.82300000e+05f, + 4.79475000e+05f, 4.98550000e+05f, 4.95100000e+05f, 5.14800000e+05f, 5.10725000e+05f, 5.31050000e+05f, + 5.26350000e+05f, 5.47300000e+05f, 5.41975000e+05f, 5.63550000e+05f, 5.57600000e+05f, 5.79800000e+05f, + 5.73225000e+05f, 5.96050000e+05f, 5.88850000e+05f, 6.12300000e+05f, 6.04475000e+05f, 6.28550000e+05f, + 6.20100000e+05f, 6.44800000e+05f, 6.35725000e+05f, 6.61050000e+05f, 6.51350000e+05f, 6.77300000e+05f, + 6.66975000e+05f, 6.93550000e+05f, 6.82600000e+05f, 7.09800000e+05f, 6.98225000e+05f, 7.26050000e+05f, + 7.13850000e+05f, 7.42300000e+05f, 7.29475000e+05f, 7.58550000e+05f, 7.45100000e+05f, 7.74800000e+05f, + 7.60725000e+05f, 7.91050000e+05f, 7.76350000e+05f, 8.07300000e+05f, 7.91975000e+05f, 8.23550000e+05f, + 8.07600000e+05f, 8.39800000e+05f, 8.23225000e+05f, 
8.56050000e+05f, 8.38850000e+05f, 8.72300000e+05f, + 8.54475000e+05f, 8.88550000e+05f, 8.70100000e+05f, 9.04800000e+05f, 8.85725000e+05f, 9.21050000e+05f, + 9.01350000e+05f, 9.37300000e+05f, 9.16975000e+05f, 9.53550000e+05f, 9.32600000e+05f, 9.69800000e+05f, + 9.48225000e+05f, 9.86050000e+05f, 9.63850000e+05f, 1.00230000e+06f, 9.79475000e+05f, 1.01855000e+06f, + 9.95100000e+05f, 1.03480000e+06f, 1.01072500e+06f, 1.05105000e+06f, 1.02635000e+06f, 1.06730000e+06f, + 1.04197500e+06f, 1.08355000e+06f, 1.05760000e+06f, 1.09980000e+06f, 1.07322500e+06f, 1.11605000e+06f, + 1.08885000e+06f, 1.13230000e+06f, 1.10447500e+06f, 1.14855000e+06f, 1.12010000e+06f, 1.16480000e+06f, + 1.13572500e+06f, 1.18105000e+06f, 1.15135000e+06f, 1.19730000e+06f, 1.16697500e+06f, 1.21355000e+06f, + 3.54260000e+06f, 3.58980000e+06f, 3.58947500e+06f, 3.63730000e+06f, 3.63635000e+06f, 3.68480000e+06f, + 3.68322500e+06f, 3.73230000e+06f, 3.73010000e+06f, 3.77980000e+06f, 3.77697500e+06f, 3.82730000e+06f, + 3.82385000e+06f, 3.87480000e+06f, 3.87072500e+06f, 3.92230000e+06f, 3.91760000e+06f, 3.96980000e+06f, + 3.96447500e+06f, 4.01730000e+06f, 4.01135000e+06f, 4.06480000e+06f, 4.05822500e+06f, 4.11230000e+06f, + 4.10510000e+06f, 4.15980000e+06f, 4.15197500e+06f, 4.20730000e+06f, 4.19885000e+06f, 4.25480000e+06f, + 4.24572500e+06f, 4.30230000e+06f, 4.29260000e+06f, 4.34980000e+06f, 4.33947500e+06f, 4.39730000e+06f, + 4.38635000e+06f, 4.44480000e+06f, 4.43322500e+06f, 4.49230000e+06f, 4.48010000e+06f, 4.53980000e+06f, + 4.52697500e+06f, 4.58730000e+06f, 4.57385000e+06f, 4.63480000e+06f, 4.62072500e+06f, 4.68230000e+06f, + 4.66760000e+06f, 4.72980000e+06f, 4.71447500e+06f, 4.77730000e+06f, 4.76135000e+06f, 4.82480000e+06f, + 4.80822500e+06f, 4.87230000e+06f, 4.85510000e+06f, 4.91980000e+06f, 4.90197500e+06f, 4.96730000e+06f, + 4.94885000e+06f, 5.01480000e+06f, 4.99572500e+06f, 5.06230000e+06f, 5.04260000e+06f, 5.10980000e+06f, + 5.08947500e+06f, 5.15730000e+06f, 5.13635000e+06f, 5.20480000e+06f, 
5.18322500e+06f, 5.25230000e+06f, + 5.23010000e+06f, 5.29980000e+06f, 5.27697500e+06f, 5.34730000e+06f, 5.32385000e+06f, 5.39480000e+06f, + 5.37072500e+06f, 5.44230000e+06f, 5.41760000e+06f, 5.48980000e+06f, 5.46447500e+06f, 5.53730000e+06f, + 5.51135000e+06f, 5.58480000e+06f, 5.55822500e+06f, 5.63230000e+06f, 5.60510000e+06f, 5.67980000e+06f, + 5.65197500e+06f, 5.72730000e+06f, 5.69885000e+06f, 5.77480000e+06f, 5.74572500e+06f, 5.82230000e+06f, + 5.79260000e+06f, 5.86980000e+06f, 5.83947500e+06f, 5.91730000e+06f, 5.88635000e+06f, 5.96480000e+06f, + 5.93322500e+06f, 6.01230000e+06f, 5.98010000e+06f, 6.05980000e+06f, 6.02697500e+06f, 6.10730000e+06f, + 6.07385000e+06f, 6.15480000e+06f, 6.12072500e+06f, 6.20230000e+06f, 6.16760000e+06f, 6.24980000e+06f, + 6.21447500e+06f, 6.29730000e+06f, 6.26135000e+06f, 6.34480000e+06f, 6.30822500e+06f, 6.39230000e+06f, + 6.35510000e+06f, 6.43980000e+06f, 6.40197500e+06f, 6.48730000e+06f, 6.44885000e+06f, 6.53480000e+06f, + 6.49572500e+06f, 6.58230000e+06f, 6.54260000e+06f, 6.62980000e+06f, 6.58947500e+06f, 6.67730000e+06f, + 6.63635000e+06f, 6.72480000e+06f, 6.68322500e+06f, 6.77230000e+06f, 6.73010000e+06f, 6.81980000e+06f, + 6.77697500e+06f, 6.86730000e+06f, 6.82385000e+06f, 6.91480000e+06f, 6.87072500e+06f, 6.96230000e+06f, + 6.91760000e+06f, 7.00980000e+06f, 6.96447500e+06f, 7.05730000e+06f, 7.01135000e+06f, 7.10480000e+06f, + 1.17619750e+07f, 1.18560500e+07f, 1.18401000e+07f, 1.19348000e+07f, 1.19182250e+07f, 1.20135500e+07f, + 1.19963500e+07f, 1.20923000e+07f, 1.20744750e+07f, 1.21710500e+07f, 1.21526000e+07f, 1.22498000e+07f, 1.22307250e+07f, 1.23285500e+07f, 1.23088500e+07f, 1.24073000e+07f, 1.23869750e+07f, 1.24860500e+07f, 1.24651000e+07f, 1.25648000e+07f, 1.25432250e+07f, 1.26435500e+07f, 1.26213500e+07f, 1.27223000e+07f, 1.26994750e+07f, 1.28010500e+07f, 1.27776000e+07f, 1.28798000e+07f, 1.28557250e+07f, 1.29585500e+07f, 1.29338500e+07f, 1.30373000e+07f, 1.30119750e+07f, 1.31160500e+07f, 1.30901000e+07f, 
1.31948000e+07f, 1.31682250e+07f, 1.32735500e+07f, 1.32463500e+07f, 1.33523000e+07f, 1.33244750e+07f, 1.34310500e+07f, 1.34026000e+07f, 1.35098000e+07f, 1.34807250e+07f, 1.35885500e+07f, 1.35588500e+07f, 1.36673000e+07f, 1.36369750e+07f, 1.37460500e+07f, 1.37151000e+07f, 1.38248000e+07f, 1.37932250e+07f, 1.39035500e+07f, 1.38713500e+07f, 1.39823000e+07f, 1.39494750e+07f, 1.40610500e+07f, 1.40276000e+07f, 1.41398000e+07f, 1.41057250e+07f, 1.42185500e+07f, 1.41838500e+07f, 1.42973000e+07f, 1.42619750e+07f, 1.43760500e+07f, 1.43401000e+07f, 1.44548000e+07f, 1.44182250e+07f, 1.45335500e+07f, 1.44963500e+07f, 1.46123000e+07f, 1.45744750e+07f, 1.46910500e+07f, 1.46526000e+07f, 1.47698000e+07f, 1.47307250e+07f, 1.48485500e+07f, 1.48088500e+07f, 1.49273000e+07f, 1.48869750e+07f, 1.50060500e+07f, 1.49651000e+07f, 1.50848000e+07f, 1.50432250e+07f, 1.51635500e+07f, 1.51213500e+07f, 1.52423000e+07f, 1.51994750e+07f, 1.53210500e+07f, 1.52776000e+07f, 1.53998000e+07f, 1.53557250e+07f, 1.54785500e+07f, 1.54338500e+07f, 1.55573000e+07f, 1.55119750e+07f, 1.56360500e+07f, 1.55901000e+07f, 1.57148000e+07f, 1.56682250e+07f, 1.57935500e+07f, 1.57463500e+07f, 1.58723000e+07f, 1.58244750e+07f, 1.59510500e+07f, 1.59026000e+07f, 1.60298000e+07f, 1.59807250e+07f, 1.61085500e+07f, 1.60588500e+07f, 1.61873000e+07f, 1.61369750e+07f, 1.62660500e+07f, 1.62151000e+07f, 1.63448000e+07f, 1.62932250e+07f, 1.64235500e+07f, 1.63713500e+07f, 1.65023000e+07f, 1.64494750e+07f, 1.65810500e+07f, 1.65276000e+07f, 1.66598000e+07f, 1.66057250e+07f, 1.67385500e+07f, 1.66838500e+07f, 1.68173000e+07f, 1.67619750e+07f, 1.68960500e+07f, 1.68401000e+07f, 1.69748000e+07f, 1.69182250e+07f, 1.70535500e+07f, 1.69963500e+07f, 1.71323000e+07f, 1.70744750e+07f, 1.72110500e+07f, 1.71526000e+07f, 1.72898000e+07f, 1.72307250e+07f, 1.73685500e+07f, 1.73088500e+07f, 1.74473000e+07f, 1.73869750e+07f, 1.75260500e+07f, 1.74651000e+07f, 1.76048000e+07f, 1.75432250e+07f, 1.76835500e+07f, 2.46688500e+07f, 2.48098000e+07f, 
2.47782250e+07f, 2.49198000e+07f, 2.48876000e+07f, 2.50298000e+07f, 2.49969750e+07f, 2.51398000e+07f, 2.51063500e+07f, 2.52498000e+07f, 2.52157250e+07f, 2.53598000e+07f, 2.53251000e+07f, 2.54698000e+07f, 2.54344750e+07f, 2.55798000e+07f, 2.55438500e+07f, 2.56898000e+07f, 2.56532250e+07f, 2.57998000e+07f, 2.57626000e+07f, 2.59098000e+07f, 2.58719750e+07f, 2.60198000e+07f, 2.59813500e+07f, 2.61298000e+07f, 2.60907250e+07f, 2.62398000e+07f, 2.62001000e+07f, 2.63498000e+07f, 2.63094750e+07f, 2.64598000e+07f, 2.64188500e+07f, 2.65698000e+07f, 2.65282250e+07f, 2.66798000e+07f, 2.66376000e+07f, 2.67898000e+07f, 2.67469750e+07f, 2.68998000e+07f, 2.68563500e+07f, 2.70098000e+07f, 2.69657250e+07f, 2.71198000e+07f, 2.70751000e+07f, 2.72298000e+07f, 2.71844750e+07f, 2.73398000e+07f, 2.72938500e+07f, 2.74498000e+07f, 2.74032250e+07f, 2.75598000e+07f, 2.75126000e+07f, 2.76698000e+07f, 2.76219750e+07f, 2.77798000e+07f, 2.77313500e+07f, 2.78898000e+07f, 2.78407250e+07f, 2.79998000e+07f, 2.79501000e+07f, 2.81098000e+07f, 2.80594750e+07f, 2.82198000e+07f, 2.81688500e+07f, 2.83298000e+07f, 2.82782250e+07f, 2.84398000e+07f, 2.83876000e+07f, 2.85498000e+07f, 2.84969750e+07f, 2.86598000e+07f, 2.86063500e+07f, 2.87698000e+07f, 2.87157250e+07f, 2.88798000e+07f, 2.88251000e+07f, 2.89898000e+07f, 2.89344750e+07f, 2.90998000e+07f, 2.90438500e+07f, 2.92098000e+07f, 2.91532250e+07f, 2.93198000e+07f, 2.92626000e+07f, 2.94298000e+07f, 2.93719750e+07f, 2.95398000e+07f, 2.94813500e+07f, 2.96498000e+07f, 2.95907250e+07f, 2.97598000e+07f, 2.97001000e+07f, 2.98698000e+07f, 2.98094750e+07f, 2.99798000e+07f, 2.99188500e+07f, 3.00898000e+07f, 3.00282250e+07f, 3.01998000e+07f, 3.01376000e+07f, 3.03098000e+07f, 3.02469750e+07f, 3.04198000e+07f, 3.03563500e+07f, 3.05298000e+07f, 3.04657250e+07f, 3.06398000e+07f, 3.05751000e+07f, 3.07498000e+07f, 3.06844750e+07f, 3.08598000e+07f, 3.07938500e+07f, 3.09698000e+07f, 3.09032250e+07f, 3.10798000e+07f, 3.10126000e+07f, 3.11898000e+07f, 3.11219750e+07f, 
3.12998000e+07f, 3.12313500e+07f, 3.14098000e+07f, 3.13407250e+07f, 3.15198000e+07f, 3.14501000e+07f, 3.16298000e+07f, 3.15594750e+07f, 3.17398000e+07f, 3.16688500e+07f, 3.18498000e+07f, 3.17782250e+07f, 3.19598000e+07f, 3.18876000e+07f, 3.20698000e+07f, 3.19969750e+07f, 3.21798000e+07f, 3.21063500e+07f, 3.22898000e+07f, 3.22157250e+07f, 3.23998000e+07f, 3.23251000e+07f, 3.25098000e+07f, 3.24344750e+07f, 3.26198000e+07f, 3.25438500e+07f, 3.27298000e+07f, 3.26532250e+07f, 3.28398000e+07f, 3.27626000e+07f, 3.29498000e+07}; + + auto a = NDArrayFactory::create('c', {2, 72, 25}); + for (int e = 0; e < a.lengthOf(); e++) + a.p(e, e+1); + + auto b = NDArrayFactory::create('c', {2, 25, 2}); + for (int e = 0; e < b.lengthOf(); e++) + b.p(e, e+1); + + NDArray exp(_expB, _expS); + + auto c = MmulHelper::mmul(&a, &b); + + ASSERT_TRUE(exp.isSameShape(c)); + ASSERT_TRUE(exp.equalsTo(c, 1e1)); + + delete c; +} + + +TEST_F(NDArrayTest, TestNegSize1) { + auto array = NDArrayFactory::create('c', {2, 5, 7}); + + ASSERT_EQ(7, array.sizeAt(-1)); + ASSERT_EQ(5, array.sizeAt(-2)); + ASSERT_EQ(2, array.sizeAt(-3)); +} + +////////////////////////////////////////////////////////////////////// +// not-in-place +TEST_F(NDArrayTest, Permute1) { + + Nd4jLong shape1[] = {3, 5, 10, 15, 150, 15, 1, 8192, 1, 99}; + Nd4jLong shape2[] = {3, 15, 5, 10, 1, 150, 15, 8192, 0, 99}; + const std::initializer_list perm = {2, 0, 1}; + + NDArray arr1(shape1,true); + NDArray arr2(shape2,true); + + auto result = arr1.permute(perm); + ASSERT_TRUE(result.isSameShapeStrict(arr2)); +} + +////////////////////////////////////////////////////////////////////// +// in-place +TEST_F(NDArrayTest, Permute2) { + + Nd4jLong shape1[] = {3, 5, 10, 15, 150, 15, 1, 8192, 1, 99}; + Nd4jLong shape2[] = {3, 15, 5, 10, 1, 150, 15, 8192, 0, 99}; + const std::initializer_list perm = {2, 0, 1}; + + NDArray arr1(shape1,true); + NDArray arr2(shape2,true); + + ASSERT_TRUE(arr1.permutei(perm)); + ASSERT_TRUE(arr1.isSameShapeStrict(arr2)); 
+} + +TEST_F(NDArrayTest, RSubScalarTest1) { + auto array = NDArrayFactory::create('c', {1, 4}); + array.assign(2.0); + + auto result = NDArrayFactory::create('c', {1, 4}); + + array.applyScalar(scalar::ReverseSubtract, 1.0, result); + + ASSERT_NEAR(-1.0, result.meanNumber().e(0), 1e-5); +} + +TEST_F(NDArrayTest, BroadcastOpsTest1) { + + auto x = NDArrayFactory::create('c', {5, 5}); + auto row = NDArrayFactory::linspace(1.0f, 5.0f, 5); + float *brow = new float[5]{1,2,3,4,5}; + auto bshape = new Nd4jLong[8]{2, 1, 5, 1, 1, 8192, 1, 99}; + float *ebuf = new float[25] {1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5}; + auto eshape = new Nd4jLong[8] {2, 5, 5, 5, 1, 8192, 1, 99}; + NDArray expRow(brow, bshape); + NDArray exp(ebuf, eshape); + + ASSERT_TRUE(row->equalsTo(&expRow)); + + + x.applyBroadcast(broadcast::Add, {1}, *row, x); + + //x.printBuffer("Result"); + + ASSERT_TRUE(x.equalsTo(&exp)); + + delete[] brow; + delete[] bshape; + delete[] ebuf; + delete[] eshape; + delete row; +} + +TEST_F(NDArrayTest, TestIndexedPut2) { + auto x = NDArrayFactory::create('f', {2, 2}); + //x.printShapeInfo("x shape"); + x.p(1, 1.0f); + + //x.printBuffer("after"); + ASSERT_NEAR(reinterpret_cast(x.buffer())[2], 1.0, 1e-5); +} + +TEST_F(NDArrayTest, TestIndexedPut3) { + auto x = NDArrayFactory::create('c', {2, 2}); + x.p(1, 1.0f); + + //x.printBuffer("after"); + ASSERT_NEAR(reinterpret_cast(x.buffer())[1], 1.0, 1e-5); +} + +TEST_F(NDArrayTest, TestIndexedPut4) { + auto x = NDArrayFactory::create('f', {2, 2}); + x.p(0, 1, 1.0f); + + //x.printBuffer("after"); + ASSERT_NEAR(reinterpret_cast(x.buffer())[2], 1.0, 1e-5); +} + + +TEST_F(NDArrayTest, TestIndexedPut5) { + auto x = NDArrayFactory::create('c', {2, 2}); + x.p(0, 1, 1.0f); + + //x.printBuffer("after"); + ASSERT_NEAR(x.bufferAsT()[1], 1.0, 1e-5); +} + +TEST_F(NDArrayTest, TestAllTensors1) { + auto matrix = NDArrayFactory::create('c', {3, 5}); + + ResultSet rows = matrix.allTensorsAlongDimension({1}); + + 
ASSERT_EQ(3, rows.size()); +} + + +TEST_F(NDArrayTest, TestIndexing1) { + auto matrix = NDArrayFactory::create('c', {5, 5}); + for (int e = 0; e < matrix.lengthOf(); e++) + matrix.p(e, (float) e); + + auto sub = matrix({2,4, 0,0}, true); + + ASSERT_EQ(2, sub.rows()); + ASSERT_EQ(5, sub.columns()); + + ASSERT_NEAR(10, sub.e(0), 1e-5); +} + + +TEST_F(NDArrayTest, TestIndexing2) { + auto matrix = NDArrayFactory::create('c', {2, 5, 4, 4}); + matrix.linspace(0); + + auto sub = matrix({0,0, 2,4, 0,0, 0,0}, true); + + ASSERT_EQ(2, sub.sizeAt(0)); + ASSERT_EQ(2, sub.sizeAt(1)); + ASSERT_EQ(4, sub.sizeAt(2)); + ASSERT_EQ(4, sub.sizeAt(3)); + + ASSERT_EQ(64, sub.lengthOf()); + ASSERT_NEAR(32, sub.e(0), 1e-5); + ASSERT_NEAR(112, sub.e(32), 1e-5); +} + +TEST_F(NDArrayTest, TestIndexing3) { + auto matrix = NDArrayFactory::create('c', {5, 5}); + matrix.linspace(0); + + auto sub = matrix({2,4, 0,0}); + + ASSERT_EQ(2, sub.rows()); + ASSERT_EQ(5, sub.columns()); + + ASSERT_NEAR(10, sub.e(0), 1e-5); +} + + +TEST_F(NDArrayTest, TestIndexing4) { + auto matrix = NDArrayFactory::create('c', {2, 5, 4, 4}); + matrix.linspace(0); + + auto sub = matrix({0,0, 2,4, 0,0, 0,0}); + + ASSERT_EQ(2, sub.sizeAt(0)); + ASSERT_EQ(2, sub.sizeAt(1)); + ASSERT_EQ(4, sub.sizeAt(2)); + ASSERT_EQ(4, sub.sizeAt(3)); + + + ASSERT_EQ(64, sub.lengthOf()); + ASSERT_NEAR(32, sub.e(0), 1e-5); + ASSERT_NEAR(112, sub.e(32), 1e-5); +} + +TEST_F(NDArrayTest, TestReshapeNegative1) { + std::unique_ptr array(NDArrayFactory::create_('c', {2, 3, 4, 64})); + + array->reshapei('c', {-1, 64}); + + ASSERT_EQ(24, array->sizeAt(0)); + ASSERT_EQ(64, array->sizeAt(1)); +} + +TEST_F(NDArrayTest, TestReshapeNegative2) { + std::unique_ptr array(NDArrayFactory::create_('c', {2, 3, 4, 64})); + + auto reshaped = array->reshape('c', {-1, 64}); + + ASSERT_EQ(24, reshaped.sizeAt(0)); + ASSERT_EQ(64, reshaped.sizeAt(1)); +} + +////////////////////////////////////////////////////////////////////// +// TEST_F(NDArrayTest, SVD1) { + +// double 
arrA[8] = {1, 2, 3, 4, 5, 6, 7, 8}; +// double arrU[8] = {-0.822647, 0.152483, -0.421375, 0.349918, -0.020103, 0.547354, 0.381169, 0.744789}; +// double arrS[2] = {0.626828, 14.269095}; +// double arrVt[4] = {0.767187,-0.641423, 0.641423, 0.767187}; + +// int shapeA[8] = {2, 4, 2, 2, 1, 0, 1, 99}; +// int shapeS[8] = {2, 1, 2, 2, 1, 0, 1, 99}; +// int shapeVt[8] = {2, 2, 2, 2, 1, 0, 1, 99}; + +// auto a(arrA, shapeA); +// auto u(arrU, shapeA); +// auto s(arrS, shapeS); +// auto vt(arrVt, shapeVt); +// auto expU, expS(shapeS), expVt(shapeVt); + +// a.svd(expU, expS, expVt); +// ASSERT_TRUE(u.equalsTo(&expU)); +// ASSERT_TRUE(s.equalsTo(&expS)); +// ASSERT_TRUE(vt.equalsTo(&expVt)); + +// } + +// ////////////////////////////////////////////////////////////////////// +// TEST_F(NDArrayTest, SVD2) { + +// double arrA[6] = {1, 2, 3, 4, 5, 6}; +// double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, 0.000000}; +// double arrS[3] = {9.508032, 0.77287, 0.000}; +// double arrVt[9] = {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, -0.816497, 0.408248}; + +// int shapeA[8] = {2, 2, 3, 3, 1, 0, 1, 99}; +// int shapeS[8] = {2, 1, 3, 3, 1, 0, 1, 99}; +// int shapeVt[8] = {2, 3, 3, 3, 1, 0, 1, 99}; + +// auto a(arrA, shapeA); +// auto u(arrU, shapeA); +// auto s(arrS, shapeS); +// auto vt(arrVt, shapeVt); +// auto expU, expS(shapeS), expVt(shapeVt); + +// a.svd(expU, expS, expVt); +// ASSERT_TRUE(u.equalsTo (&expU)); +// ASSERT_TRUE(s.equalsTo(&expS)); +// ASSERT_TRUE(vt.equalsTo(&expVt)); + +// } + +// ////////////////////////////////////////////////////////////////////// +// TEST_F(NDArrayTest, SVD3) { + +// double arrA[8] = {1, 2, 3, 4, 5, 6, 7, 8}; +// double arrU[8] = {-0.822647, 0.152483, -0.421375, 0.349918, -0.020103, 0.547354, 0.381169, 0.744789}; +// double arrS[2] = {0.626828, 14.269095}; +// double arrVt[4] = {0.767187,-0.641423, 0.641423, 0.767187}; + +// int shapeA[8] = {2, 4, 2, 2, 1, 0, 1, 99}; +// int shapeS[8] = {2, 
1, 2, 2, 1, 0, 1, 99}; +// int shapeVt[8] = {2, 2, 2, 2, 1, 0, 1, 99}; + +// auto a(arrA, shapeA); +// auto u(arrU, shapeA); +// auto s(arrS, shapeS); +// auto vt(arrVt, shapeVt); +// auto expU, expS(shapeS), expVt(shapeVt); + +// a.svd(expU, expS, expVt); +// ASSERT_TRUE(expU.hasOrthonormalBasis(1)); +// ASSERT_TRUE(expVt.hasOrthonormalBasis(0)); +// ASSERT_TRUE(expVt.hasOrthonormalBasis(1)); +// ASSERT_TRUE(expVt.isUnitary()); +// } + +// ////////////////////////////////////////////////////////////////////// +// TEST_F(NDArrayTest, SVD4) { + +// double arrA[6] = {1, 2, 3, 4, 5, 6}; +// double arrU[6] = {-0.386318, -0.922366, 0.000000, -0.922366, 0.386318, 0.000000}; +// double arrS[3] = {9.508032, 0.77287, 0.000}; +// double arrVt[9] = {-0.428667, -0.566307, -0.703947, 0.805964, 0.112382, -0.581199, 0.408248, -0.816497, 0.408248}; + +// int shapeA[8] = {2, 2, 3, 3, 1, 0, 1, 99}; +// int shapeS[8] = {2, 1, 3, 3, 1, 0, 1, 99}; +// int shapeVt[8] = {2, 3, 3, 3, 1, 0, 1, 99}; + +// auto a(arrA, shapeA); +// auto u(arrU, shapeA); +// auto s(arrS, shapeS); +// auto vt(arrVt, shapeVt); +// auto expU, expS(shapeS), expVt(shapeVt); + +// a.svd(expU, expS, expVt); +// ASSERT_TRUE(expU.hasOrthonormalBasis(1)); +// ASSERT_TRUE(expVt.hasOrthonormalBasis(0)); +// ASSERT_TRUE(expVt.hasOrthonormalBasis(1)); +// ASSERT_TRUE(expVt.isUnitary()); +// } + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestStdDev1) { + auto array = NDArrayFactory::create('c', {1, 5}); + for (int e = 0; e < array.lengthOf(); e++) + array.p(e, e+1); + + auto std = array.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + ASSERT_NEAR(std, 1.58109, 1e-4); +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestStdDev2) { + auto array = NDArrayFactory::create('c', {5, 6}); + auto tad = array(0, {1}); + + ASSERT_EQ(5, tad.lengthOf()); + + for (int e = 0; e < tad.lengthOf(); e++) 
+ tad.p(e, e+1); + + ASSERT_NEAR(15, tad.sumNumber().e(0), 1e-5); + + auto std = tad.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + ASSERT_NEAR(std, 1.58109, 1e-4); +} + +TEST_F(NDArrayTest, TestStdDev3) { + auto array = NDArrayFactory::create('c', {1, 50000}); + for (int e = 0; e < array.lengthOf(); e++) + array.p(e, 1.f + (e%2?0.5f:-0.5f)); + + auto std = array.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + // nd4j_printf("Variance is %f\n", std); + ASSERT_NEAR(std, 0.5f, 1.0e-5f); +} + +TEST_F(NDArrayTest, TestStdDev4) { + auto array = NDArrayFactory::create('c', {1, 20000}); + float const ethalon = 1 / 3.f; + float x = ethalon; + int total = array.lengthOf(); + for (int e = 0; e < total; e++) { + array.p(e, 1.0f + (e % 2?ethalon:-ethalon)); + x *= (e % 2? 2.f: 0.5f); + } + x = 0.f; + for (int e = 0; e < total; ++e) { + x += array.e(e); + } + x /= array.lengthOf(); + float y = 0; + double M2 = 0; + for (int e = 0; e < total; ++e) { + // y += sd::math::nd4j_abs(array(e) - x); + M2 += (array.e(e) - x) * (array.e(e) - x); + } + //y /= total; + M2 /= total; + + y = M2; + auto a = array.varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto std = a.e(0); +// float bY = array.varianceNumber(); + float bY = 0.3333333f; + // nd4j_printf("Variance is %f, res is %f, internal is %f\n, deviance is %f(%f)\n", std, x, bY, y, sd::math::nd4j_sqrt(M2)); + ASSERT_NEAR(std, 0.3333333f, 1.0e-5f); +} + +TEST_F(NDArrayTest, TestStdDev5) { + auto array = NDArrayFactory::create('c', {1, 10000}); //00000}); + auto arrayD = NDArrayFactory::create('c', {1, 10000}); //00000}); + for (int e = 0; e < array.lengthOf(); e++) { + array.p(e, 1.f + (e%2?1/5.f:-1/5.f)); + arrayD.p(e, 1.0 + (e%2?1/5.:-1/5.)); + } + float stdF = array.varianceNumber(variance::SummaryStatsStandardDeviation, false).e(0); + double stdD = arrayD.varianceNumber(variance::SummaryStatsStandardDeviation, false).e(0); + // nd4j_printf("Variance is %f(%f)\n", stdF, 
stdD); + ASSERT_NEAR(stdD, 0.2, 1.0e-8); // 1/5 = 0.2 + ASSERT_NEAR(stdF, 0.2f, 1.0e-5f); // 1/5 = 0.2 +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestApplyIndexReduce1) { + float xBuff[] = {1, 5, 2, 12, 9, 3, 10, 7, 4, 11, 6, 8}; + Nd4jLong xShapeInfo[] = {3, 2, 3, 2, 6, 2, 1, 8192, 1, 99}; + std::vector dim = {0,1}; + + NDArray x(xBuff, xShapeInfo); + auto exp = NDArrayFactory::create({3, 1}); + + auto result = x.applyIndexReduce(indexreduce::IndexMax, dim); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, applyReduce3Dot) { + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float yBuff[] = {2, 2, 2, 2, 2, 2}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, xShapeInfo); + + auto result = x.applyReduce3(reduce3::Dot, y); + ASSERT_TRUE(result.lengthOf() == 1); + ASSERT_NEAR(42, result.e(0), 1e-5); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, applyAllReduce3EuclideanDistance) { + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float yBuff[] = {2, 2, 2, 2, 2, 2}; + float expBuff[] = {1.414214f, 1.414214f, 5.385165f, 5.385165f}; + Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, xShapeInfo); + auto exp = NDArrayFactory::create('c', {2, 2}, {1.414214f, 1.414214f, 5.385165f, 5.385165f}); + + auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); + + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, applyReduce3EuclideanDistance) { + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float yBuff[] = {2, 2, 2, 2, 2, 2}; + float expBuff[] = {1.414214f, 
1.414214f, 5.385165f, 5.385165f}; + Nd4jLong expShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, xShapeInfo); + NDArray exp(expBuff, expShapeInfo); + + auto result = x.applyAllReduce3(reduce3::EuclideanDistance, y ,{1}); + + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestVarianceAlongDimension1) { + + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float expBuff[] = {0.816497f, 0.816497f}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; + + NDArray x(xBuff, xShapeInfo); + NDArray exp(expBuff, expShapeInfo); + + auto result = x.varianceAlongDimension(variance::SummaryStatsStandardDeviation, false, {1}); + + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestVarianceAlongDimension2) { + float xBuff[] = {1, 2, 3, 4, 5, 6}; + float expBuff[] = {0.666667f, 0.666667f}; + Nd4jLong xShapeInfo[] = {2, 2, 3, 3, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {1, 2, 1, 8192, 1, 99}; + + + NDArray x(xBuff, xShapeInfo); + NDArray exp(expBuff, expShapeInfo); + + auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {1}); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); +} +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestVarianceAlongDimension3) { + + + NDArray x = NDArrayFactory::create('c', {10, 10});//(xBuff, xShapeInfo); + NDArray exp = NDArrayFactory::create('c', {10});//(expBuff, expShapeInfo); + x.linspace(1); // 1, 2, 3, ..., 100 + exp.assign(825.f); + auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); + 
ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestVarianceAlongDimension4) { + + + NDArray x = NDArrayFactory::create('c', {12, 1, 12});//(xBuff, xShapeInfo); + NDArray exp = NDArrayFactory::create('c', {1,12});//(expBuff, expShapeInfo); + x.linspace(1); // 1, 2, 3, ..., 100 + exp.assign(1716.); + auto result = x.varianceAlongDimension(variance::SummaryStatsVariance, false, {0}); + ASSERT_TRUE(exp.isSameShapeStrict(result)); + ASSERT_TRUE(exp.equalsTo(result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestSubRowVector1) { + float xBuff[] = {6, 7, 8, 9}; + float yBuff[] = {1, 2}; + float expBuff[] = {5, 5, 7, 7}; + Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, yShapeInfo); + NDArray target(x); + NDArray exp(expBuff, xShapeInfo); + + x.subRowVector(y, target); + + ASSERT_TRUE(exp.isSameShapeStrict(target)); + ASSERT_TRUE(exp.equalsTo(&target)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestDivRowVector1) { + float xBuff[] = {6, 8, 10, 12}; + float yBuff[] = {2, 4}; + float expBuff[] = {3, 2, 5, 3}; + Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, yShapeInfo); + NDArray target(x); + NDArray exp(expBuff, xShapeInfo); + + x.divRowVector(y, target); + + ASSERT_TRUE(exp.isSameShapeStrict(target)); + ASSERT_TRUE(exp.equalsTo(&target)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMulRowVector1) { + float xBuff[] = {6, 8, 10, 12}; + float yBuff[] = {2, 4}; + float expBuff[] = {12, 32, 20, 48}; + Nd4jLong xShapeInfo[] = {2, 2, 2, 2, 1, 8192, 1, 99}; 
+ Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + + NDArray x(xBuff, xShapeInfo); + NDArray y(yBuff, yShapeInfo); + NDArray target(x); + NDArray exp(expBuff, xShapeInfo); + + x.mulRowVector(y, target); + + ASSERT_TRUE(exp.isSameShapeStrict(target)); + ASSERT_TRUE(exp.equalsTo(&target)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestTensorDotAgain_1) { + int sY = 1; + int sX = 1; + int pY = 0; + int pX = 0; + int iC = 2; + int oC = 2; + int kY = 3; + int kX = 3; + int iY = 2; + int iX = 2; + int oY = 6; + int oX = 6; + int eD = iC * oC; + int B = 2; + + /* + input = np.linspace(1, B * iC * iY * iX, B * iC * iY * iX).reshape(B, iC, iY, iX) + weights = np.linspace(1, iC * oC * kY * kX, iC * oC * kY * kX).reshape(iC, oC, kY, kX) + */ + double _expB[] = {96.0, 116.0, 136.0, 156.0, 256.0, 276.0, 296.0, 316.0, 102.0, 124.0, 146.0, 168.0, 278.0, 300.0, 322.0, 344.0, 108.0, 132.0, 156.0, 180.0, 300.0, 324.0, 348.0, 372.0, 114.0, 140.0, 166.0, 192.0, 322.0, 348.0, 374.0, 400.0, 120.0, 148.0, 176.0, 204.0, 344.0, 372.0, 400.0, 428.0, 126.0, 156.0, 186.0, 216.0, 366.0, 396.0, 426.0, 456.0, 132.0, 164.0, 196.0, 228.0, 388.0, 420.0, 452.0, 484.0, 138.0, 172.0, 206.0, 240.0, 410.0, 444.0, 478.0, 512.0, 144.0, 180.0, 216.0, 252.0, 432.0, 468.0, 504.0, 540.0, 150.0, 188.0, 226.0, 264.0, 454.0, 492.0, 530.0, 568.0, 156.0, 196.0, 236.0, 276.0, 476.0, 516.0, 556.0, 596.0, 162.0, 204.0, 246.0, 288.0, 498.0, 540.0, 582.0, 624.0, 168.0, 212.0, 256.0, 300.0, 520.0, 564.0, 608.0, 652.0, 174.0, 220.0, 266.0, 312.0, 542.0, 588.0, 634.0, 680.0, 180.0, 228.0, 276.0, 324.0, 564.0, 612.0, 660.0, 708.0, 186.0, 236.0, 286.0, 336.0, 586.0, 636.0, 686.0, 736.0, 192.0, 244.0, 296.0, 348.0, 608.0, 660.0, 712.0, 764.0, 198.0, 252.0, 306.0, 360.0, 630.0, 684.0, 738.0, 792.0}; + + Nd4jLong _expS[] = {6, 2, 3, 3, 2, 2, 2, 72, 24, 8, 4, 2, 1, 16384, 1, 99}; + NDArray exp(_expB, _expS, sd::LaunchContext ::defaultContext(), false); + + auto 
input = NDArrayFactory::create('c', {B, iC, iY, iX}); + auto weights = NDArrayFactory::create('c', {iC, oC, kY, kX}); + + input.linspace(1); + weights.linspace(1); + + auto result = MmulHelper::tensorDot(&weights, &input, {0}, {1}); + + //result->printShapeInfo("result shape"); + ASSERT_TRUE(exp.isSameShape(result)); + +// exp.printBuffer("Expctd buffer"); +// result->printBuffer("Result buffer"); + ASSERT_TRUE(exp.equalsTo(result)); + + delete result; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestBroadcast_1) { + double _expB[] = {1.000000, 1.000000, 1.000000, 1.000000, 2.000000, 2.000000, 2.000000, 2.000000, 3.000000, 3.000000, 3.000000, 3.000000, 1.000000, 1.000000, 1.000000, 1.000000, 2.000000, 2.000000, 2.000000, 2.000000, 3.000000, 3.000000, 3.000000, 3.000000}; + Nd4jLong _expS[] = {4, 2, 3, 2, 2, 12, 4, 2, 1, 16384, 1, 99}; + NDArray exp(_expB, _expS, sd::LaunchContext ::defaultContext(), false); + + auto input = NDArrayFactory::create('c',{ 2, 3, 2, 2}); + auto bias = NDArrayFactory::create('c', {1, 3}); + + bias.linspace(1); + + input.applyBroadcast(broadcast::Add, {1}, bias, input); + + //input.printBuffer("result"); + ASSERT_TRUE(exp.equalsTo(&input)); +} + +TEST_F(NDArrayTest, TestTranspose_11) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + x.transposei(); + + ASSERT_EQ(4, x.sizeAt(0)); + ASSERT_EQ(3, x.sizeAt(1)); + ASSERT_EQ(2, x.sizeAt(2)); +} + + +TEST_F(NDArrayTest, TestTranspose_12) { + auto x = NDArrayFactory::create('c', {2, 3, 4}); + auto y = x.transpose(); + + ASSERT_EQ(4, y.sizeAt(0)); + ASSERT_EQ(3, y.sizeAt(1)); + ASSERT_EQ(2, y.sizeAt(2)); + + ASSERT_EQ(2, x.sizeAt(0)); + ASSERT_EQ(3, x.sizeAt(1)); + ASSERT_EQ(4, x.sizeAt(2)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMMulMultiDim) { + const int bS=2; + const int K=3; + const int N=4; + + auto input = NDArrayFactory::create('c', {bS, K, N}); + auto weights = 
NDArrayFactory::create('c', {3*K, K}); + auto expected = NDArrayFactory::create('c', {bS, 3*K, N}, { 38, 44, 50, 56, 83, 98, 113, 128, 128, 152, 176, 200, 173, 206, 239, 272, 218, 260, 302, 344, 263, 314, 365, 416, 308, 368, 428, 488, 353, 422, 491, 560, 398, 476, 554, 632, 110, 116, 122, 128, 263, 278, 293, 308, 416, 440, 464, 488, 569, 602, 635, 668, 722, 764, 806, 848, 875, 926, 977, 1028, 1028, 1088, 1148, 1208, 1181, 1250, 1319, 1388, 1334, 1412, 1490, 1568}); + + input.linspace(1); + weights.linspace(1); + + auto result = MmulHelper::mmul(&weights, &input, nullptr, 1., 0.); + // result must have such shape [bS x 3K x N] + + ASSERT_TRUE(result->isSameShape(&expected)); + + //result->printShapeInfo("result shape"); + // result->printBuffer("result buffer"); + ASSERT_TRUE(result->equalsTo(&expected)); + delete result; +} + + +TEST_F(NDArrayTest, AdditionOperator1) { + + auto input1 = NDArrayFactory::create('c', {2,2}); + auto input2 = NDArrayFactory::create('c', {2,2}); + auto expected = NDArrayFactory::create('c', {2,2}); + + input1.assign(1.5); + input2.assign(2.); + expected.assign(3.5); + + input2 = input1 + input2; + + ASSERT_TRUE(input2.equalsTo(&expected)); + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMatmMul_Again_1) { + auto a = NDArrayFactory::create('c', {3, 4, 1}); + auto b = NDArrayFactory::create('c', {3, 1, 5}); + + a.linspace(1); + b.linspace(1); + + float _expB[] = {1.f, 2.f, 3.f, 4.f, 5.f, 2.f, 4.f, 6.f, 8.f, 10.f, 3.f, 6.f, 9.f, 12.f, 15.f, 4.f, 8.f, 12.f, 16.f, 20.f, 30.f, 35.f, 40.f, 45.f, 50.f, 36.f, 42.f, 48.f, 54.f, 60.f, 42.f, 49.f, 56.f, 63.f, 70.f, 48.f, 56.f, 64.f, 72.f, 80.f, 99.f, 108.f, 117.f, 126.f, 135.f, 110.f, 120.f, 130.f, 140.f, 150.f, 121.f, 132.f, 143.f, 154.f, 165.f, 132.f, 144.f, 156.f, 168.f, 180.f}; + Nd4jLong _expS[] = {3, 3, 4, 5, 20, 5, 1, 8192, 1, 99}; + NDArray c(_expB, _expS, sd::LaunchContext ::defaultContext(), false); + + auto c_ = MmulHelper::mmul(&a, 
&b); + + ASSERT_TRUE(c.isSameShape(c_)); + ASSERT_TRUE(c.equalsTo(c_)); + + delete c_; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, TestMatmMul_Again_2) { + auto a = NDArrayFactory::create('c', {2, 5, 4}); + auto b = NDArrayFactory::create('c', {2, 4, 1}); + + a.linspace(1); + b.linspace(1); + + double _expB[] = {30.f, 70.f, 110.f, 150.f, 190.f, 590.f, 694.f, 798.f, 902.f, 1006.f}; + Nd4jLong _expS[] = {3, 2, 5, 1, 5, 1, 1, 16384, 1, 99}; + NDArray c(_expB, _expS); + + auto c_ = MmulHelper::mmul(&a, &b); + + ASSERT_TRUE(c.isSameShape(c_)); + + ASSERT_TRUE(c.equalsTo(c_)); + + delete c_; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Plus_Test_1) +{ + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; + + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + + //x.printShapeInfo("x shape"); + //y.printShapeInfo("y shape"); + //expected.printShapeInfo("e shape"); + //expected.printIndexedBuffer("e"); + + x.linspace(1); + y.linspace(1); + + auto result = x + y; + + //result.printIndexedBuffer("result"); + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Plus_Test_2) +{ + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + + x.linspace(1); + y.linspace(1); + + auto result = x + y; + // result.printIndexedBuffer(); + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, 
Operator_Plus_Test_3) +{ + double expBuff[] = {2., 3, 3., 4., 4., 5, 5., 6., 6., 7, 7., 8.}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + + x.linspace(1); + y.linspace(1); + + auto result = x + y; + // result.printIndexedBuffer(); + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Plus_Test_4) +{ + double expBuff[] = {11.,12., 12.,13., 13.,14., 14.,15., 13.,14., 14.,15., 15.,16., 16.,17., 15.,16., 16.,17., 17.,18., 18.,19.}; + + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x + y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Minus_Test_1) +{ + double expBuff[] = {9. 
,10., 10.,11., 11.,12., 12.,13., 17.,18., 18.,19., 19.,20., 20.,21., 25.,26., 26.,27., 27.,28., 28.,29.}; + + auto x = NDArrayFactory::create('c', {3, 4, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x - y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Minus_Test_2) +{ + double expBuff[] = {9., 8., 7., 6., 6., 5., 4., 3., 11.,10., 9., 8., 8., 7., 6., 5., 13.,12.,11.,10., 10., 9., 8., 7.}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + + x.linspace(10); + y.linspace(1); + + auto result = x - y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Minus_Test_3) +{ + double expBuff[] = {9., 8., 7., 6., 6., 5., 4., 3., 11.,10., 9., 8., 8., 7., 6., 5., 13.,12.,11.,10., 10., 9., 8., 7.}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + + x.linspace(10); + y.linspace(1); + + auto result = x - y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Minus_Test_4) +{ + double expBuff[] = {9.,10., 8., 9., 11.,12.,10.,11., 13.,14.,12.,13.}; + + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x - y; + + 
ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Minus_Test_5) +{ + double expBuff[] = {9. ,8 ,10., 9., 11.,10, 12.,11., 13.,12, 14.,13.}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x - y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Minus_Test_6) +{ + double expBuff[] = {9., 8, 10., 9, 11.,10, 12.,11., 13.,12, 14.,13, 15.,14, 16.,15., 17.,16, 18.,17, 19.,18, 20.,19.}; + + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {1, 1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x - y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Multiply_Test_1) +{ + double expBuff[] = {10., 11., 24., 26., 42., 45., 64., 68., 18., 19., 40., 42., 66., 69., 96.,100., 26., 27., 56., 58., 90., 93., 128.,132.}; + + auto x = NDArrayFactory::create('c', {3, 4, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x * y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Multiply_Test_2) +{ + double expBuff[] = {10.,20., 30., 40., 55.,66., 77., 88., 12.,24., 36., 48., 65.,78., 91.,104., 
14.,28., 42., 56., 75.,90.,105.,120.}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + + x.linspace(10); + y.linspace(1); + + auto result = x * y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Multiply_Test_3) +{ + double expBuff[] = {10.,20., 30., 40.,55.,66., 77., 88., 12.,24., 36., 48.,65.,78., 91.,104., 14.,28., 42., 56.,75.,90.,105.,120.}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + + x.linspace(10); + y.linspace(1); + + auto result = x * y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Multiply_Test_4) +{ + double expBuff[] = {10.,11.,20.,22., 12.,13.,24.,26., 14.,15.,28.,30.}; + + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x * y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Multiply_Test_5) +{ + double expBuff[] = {10.,20.,11.,22., 12.,24.,13.,26., 14.,28.,15.,30.}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x * y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} 
+ + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Multiply_Test_6) +{ + double expBuff[] = {10,11.,12.,13.,28.,30.,32.,34.,54.,57.,60.,63.}; + + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {3, 1, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 1}); + + x.linspace(10); + y.linspace(1); + + auto result = x * y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Divide_Test_1) +{ + double expBuff[] = {10. ,11. , 6. , 6.5 , 4.6666, 5. , 4. , 4.25 , 18. ,19. , 10. ,10.5 , 7.3333, 7.6666, 6. , 6.25 , 26. ,27. , 14. ,14.5 , 10. ,10.3333, 8. , 8.25}; + + auto x = NDArrayFactory::create('c', {3, 4, 2}); + auto y = NDArrayFactory::create('c', {4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x / y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result,1e-4)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Divide_Test_2) +{ + double expBuff[] = {10. ,5. ,3.333333,2.5 , 2.2,1.83333,1.571428,1.375, 12. ,6. ,4. ,3. , 2.6,2.16666,1.857142,1.625, 14. ,7. ,4.666666,3.5 , 3. ,2.5 ,2.142857,1.875}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + + x.linspace(10); + y.linspace(1); + + auto result = x / y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Divide_Test_3) +{ + double expBuff[] = {10. ,5. ,3.333333,2.5 , 2.2,1.833333,1.571428,1.375, 12. ,6. ,4. ,3. 
, 2.6,2.166666,1.857142,1.625, 14. ,7. ,4.666666,3.5 , 3. ,2.5 ,2.142857,1.875}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {2, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 4}); + + x.linspace(10); + y.linspace(1); + + auto result = x / y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Divide_Test_4) +{ + double expBuff[] = {10.,11., 5., 5.5, 12.,13., 6., 6.5, 14.,15., 7., 7.5}; + + auto x = NDArrayFactory::create('c', {3, 1, 2}); + auto y = NDArrayFactory::create('c', {2, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x / y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Divide_Test_5) +{ + double expBuff[] = {10.,5., 11.,5.5, 12.,6., 13.,6.5, 14.,7., 15.,7.5}; + + auto x = NDArrayFactory::create('c', {3, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 2}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 2, 2}); + + x.linspace(10); + y.linspace(1); + + auto result = x / y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Divide_Test_6) +{ + double expBuff[] = {10. , 5.5 , 4. , 3.25 ,14. , 7.5 , 5.333333, 4.25 ,18. 
, 9.5 , 6.666666, 5.25}; + + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {1, 4, 1}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 1}); + + x.linspace(10); + y.linspace(1); + + auto result = x / y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Operator_Divide_Test_7) +{ + double expBuff[] = {10., 5. ,3.333333,2.5 ,11., 5.5,3.666666,2.75,12., 6. ,4. ,3. ,13., 6.5,4.333333,3.25, 14., 7. ,4.666666,3.5 ,15., 7.5,5. ,3.75,16., 8. ,5.333333,4. ,17., 8.5,5.666666,4.25, 18., 9. ,6. ,4.5 ,19., 9.5,6.333333,4.75,20.,10. ,6.666666,5. ,21.,10.5,7. ,5.25}; + + auto x = NDArrayFactory::create('c', {3, 4, 1}); + auto y = NDArrayFactory::create('c', {1, 1, 4}); + auto expected = NDArrayFactory::create(expBuff, 'c', {3, 4, 4}); + + x.linspace(10); + y.linspace(1); + + auto result = x / y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +#ifndef __CUDABLAS__ +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_Lambda_1) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {4, 5, 6, 7, 8}); + + auto lambda = LAMBDA_F(_val) { + return _val + 3.0f; + }; + + x.applyLambda(lambda, x); + + ASSERT_TRUE(exp.equalsTo(&x)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_Lambda_2) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {3, 5, 3, 5, 3}); + + auto lambda = LAMBDA_FF(_x, _y) { + return _x + _y + 1.0f; + }; + + x.applyPairwiseLambda(y, lambda, x); + + ASSERT_TRUE(exp.equalsTo(&x)); +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_Lambda_3) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 2, 1, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {4, 8, 4, 8, 4}); + + auto lambda = LAMBDA_DD(_x, _y) { + return (_x + _y) * 2; + }; + + x.applyPairwiseLambda(y, lambda, x); + + ASSERT_TRUE(exp.equalsTo(&x)); +} + +#endif + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_swapUnsafe_1) { + + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {5, 6, 7, 8}); + auto expX = NDArrayFactory::create('c', {2, 2}, {5, 6, 7, 8}); + auto expY = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + + x.swapUnsafe(y); + + ASSERT_TRUE(expX.equalsTo(&x)); + ASSERT_TRUE(expY.equalsTo(&y)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_1) { + + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + auto exp = NDArrayFactory::create('c', {2, 1}, {1, 5}); + + auto diag = x.diagonal('c'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_2) { + + auto x = NDArrayFactory::create('f', {2, 3}); + auto exp = NDArrayFactory::create('f', {2, 1}, {1, 5}); + x.linspace(1); + + auto diag = x.diagonal('c'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_3) { + + auto x = NDArrayFactory::create('c', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 2}, {1, 4}); + + auto diag = x.diagonal('r'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_4) { + + auto x = NDArrayFactory::create('f', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 2}, {1, 4}); + + auto diag = x.diagonal('r'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_5) { + + auto x = NDArrayFactory::create('c', {2, 2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 2}, {1, 8}); + + auto diag = x.diagonal('r'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_6) { + + auto x = NDArrayFactory::create('f', {2, 2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 2}, {1, 8}); + + auto diag = x.diagonal('r'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_7) { + + auto x = NDArrayFactory::create('f', {2, 2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {2, 1}, {1, 8}); + + auto diag = x.diagonal('c'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_8) { + + auto x = NDArrayFactory::create('c', {2, 3}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 2}, {1, 5}); + + auto diag = x.diagonal('r'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_9) { + + auto x = NDArrayFactory::create('c', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {2, 1}, {1, 4}); + + 
auto diag = x.diagonal('c'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_10) { + + auto x = NDArrayFactory::create('f', {2, 2}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {2, 1}, {1, 4}); + + auto diag = x.diagonal('c'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_11) { + + auto x = NDArrayFactory::create('f', {3, 3}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {3, 1}, {1, 5, 9}); + + auto diag = x.diagonal('c'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_12) { + + auto x = NDArrayFactory::create('c', {3, 3}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 5, 9}); + + auto diag = x.diagonal('r'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_13) { + + auto x = NDArrayFactory::create('c', {3, 3, 4}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3, 1}, {1,18,35}); + + auto diag = x.diagonal('c'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_14) { + + auto x = NDArrayFactory::create('c', {3, 3, 4}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 3}, {1,18,35}); + + auto diag = x.diagonal('r'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_15) { + + 
auto x = NDArrayFactory::create('f', {3, 3, 4}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 3}, {1,18,35}); + + auto diag = x.diagonal('r'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_16) { + + auto x = NDArrayFactory::create('f', {1, 5}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 1}, {1}); + + auto diag = x.diagonal('c'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_17) { + + auto x = NDArrayFactory::create('c', {5, 1}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {1, 1}, {1}); + + auto diag = x.diagonal('r'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, Test_diagonal_18) { + + auto x = NDArrayFactory::create('f', {1, 1}); + x.linspace(1); + auto exp = NDArrayFactory::create('f', {1, 1}, {1}); + + auto diag = x.diagonal('r'); + + ASSERT_TRUE(exp.isSameShape(diag)); + ASSERT_TRUE(exp.equalsTo(diag)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest, assign_test1) { + + NDArray x('c', {2, 3}, {1,2,3,4,5,6}); + NDArray y('c', {2, 3}, {10,20,30,40,50,60}); + y.reshapei('c',{3, 2}); + + x.assign(y); + x.reshapei('c',{3, 2}); + ASSERT_TRUE(x.equalsTo(y)); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayTests2.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayTests2.cpp new file mode 100644 index 000000000..18a8aeac5 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NDArrayTests2.cpp @@ -0,0 +1,1309 @@ +/* 
****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 21.11.17. +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; + +////////////////////////////////////////////////////////////////////// +class NDArrayTest2 : public testing::Test { +public: + +}; + + +TEST_F(NDArrayTest2, Test_ByteVector_1) { + auto x = NDArrayFactory::create('c', {10, 10}); + x.linspace(1); + + auto vec = x.asByteVector(); + + auto restored = new NDArray((float *)vec.data(), x.shapeInfo(), x.getContext(), false); + + + ASSERT_TRUE(x.equalsTo(restored)); + + delete restored; +} + +TEST_F(NDArrayTest2, Test_ByteVector_2) { + auto x = NDArrayFactory::create('c', {10, 10}); + x.linspace(1); + + auto vec = x.asByteVector(); + + auto restored = new NDArray((bfloat16 *)vec.data(), x.shapeInfo(), x.getContext(), false); + + ASSERT_TRUE(x.equalsTo(restored)); + + delete restored; +} + +TEST_F(NDArrayTest2, Test_ByteVector_3) { + auto x = NDArrayFactory::create('c', {10, 10}); + x.linspace(1); + + auto vec = x.asByteVector(); + + auto restored = new NDArray((double *)vec.data(), x.shapeInfo(), x.getContext(), false); + + 
ASSERT_TRUE(x.equalsTo(restored)); + + delete restored; +} + +TEST_F(NDArrayTest2, Test_Reshape_Scalar_1) { + auto x = NDArrayFactory::create('c', {1, 1}, {1.0}); + auto e = NDArrayFactory::create(1.0); + + x.reshapei({}); + + ASSERT_EQ(e, x); + ASSERT_EQ(e.rankOf(), x.rankOf()); +} + +TEST_F(NDArrayTest2, Test_Reshape_Scalar_2) { + auto x = NDArrayFactory::create('c', {1, 1}, {1.0}); + auto e = NDArrayFactory::create('c', {1}, {1.0}); + + x.reshapei({1}); + + ASSERT_EQ(e, x); + ASSERT_EQ(e.rankOf(), x.rankOf()); +} + +TEST_F(NDArrayTest2, Test_IndexReduce_1) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + + ExtraArguments extras({3.0, 0.0, 10.0}); + int idx = x.indexReduceNumber(indexreduce::FirstIndex, &extras).e(0); + + ASSERT_EQ(2, idx); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, SetIdentity_test_1) { + + auto x = NDArrayFactory::create('c', {1, 5}); + auto xExp = NDArrayFactory::create('c', {1, 5}, {1, 0, 0, 0, 0}); + + x.setIdentity(); + ASSERT_TRUE(x.equalsTo(&xExp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, SetIdentity_test_2) { + + auto x = NDArrayFactory::create('f', {1, 5}); + auto xExp = NDArrayFactory::create('f', {1, 5}, {1, 0, 0, 0, 0}); + + x.setIdentity(); + + ASSERT_TRUE(x.equalsTo(&xExp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, SetIdentity_test_3) { + + auto x = NDArrayFactory::create('f', {1, 1}); + auto xExp = NDArrayFactory::create('f', {1, 1}, {1}); + + x.setIdentity(); + + ASSERT_TRUE(x.equalsTo(&xExp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, SetIdentity_test_4) { + + auto x = NDArrayFactory::create('f', {2, 1}); + auto xExp = NDArrayFactory::create('f', {2, 1}, {1,0}); + + x.setIdentity(); + + ASSERT_TRUE(x.equalsTo(&xExp)); +} + 
+////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, SetIdentity_test_5) { + + auto x = NDArrayFactory::create('f', {2, 2}); + auto xExp = NDArrayFactory::create('f', {2, 2}, {1,0,0,1}); + + x.setIdentity(); + + ASSERT_TRUE(x.equalsTo(&xExp)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, SetIdentity_test_6) { + + auto x = NDArrayFactory::create('c', {3, 2}); + auto xExp = NDArrayFactory::create('c', {3, 2}, {1.f, 0.f, 0.f, 1.f, 0.f, 0.f}); + + x.setIdentity(); + + ASSERT_TRUE(x.equalsTo(&xExp)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, SetIdentity_test_7) { + + auto x = NDArrayFactory::create('c', {3, 4}); + auto xExp = NDArrayFactory::create('c', {3, 4}, {1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, 0.f, 0.f, 1.f, 0.f}); + + x.setIdentity(); + + ASSERT_TRUE(x.equalsTo(&xExp)); +} + +#ifdef ALLOWED_3D_IDENTITY +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, SetIdentity_test_8) { + + auto x = NDArrayFactory::create('c', {3, 3, 3}); + auto xExp = NDArrayFactory::create('c', {3, 3, 3}, {1.,0.,0. ,0.,0.,0., 0.,0.,0., 0.,0.,0. ,0.,1.,0., 0.,0.,0., 0.,0.,0. 
,0.,0.,0., 0.,0.,1.}); + x.setIdentity(); + + ASSERT_TRUE(x.equalsTo(&xExp)); +} +#endif + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_AllReduce3_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + auto y = NDArrayFactory::create('c', {2, 3}, {2, 3, 4, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 2}, {1.73205, 1.73205, 1.73205, 1.73205}); + + auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_AllReduce3_2) { + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4 }); + auto y = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0., 1.73205, 1.73205, 0.}); + + auto z = x.applyAllReduce3(reduce3::EuclideanDistance, y, {1}); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, mmul_test1) { + + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1,2, 3, 4,2,4, 6, 8,3,6, 9,12,4,8,12,16}); + + auto result = mmul(x, y); + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, mmul_test2) { + + auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {1, 1}, {30}); + + auto result = mmul(y ,x); + + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, mmul_test3) { + + 
auto x = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1. ,0.2 ,0.3 ,0.4 ,0.2,0.04,0.06,0.08,0.3,0.06,0.09,0.12,0.4,0.08,0.12,0.16}); + auto w = NDArrayFactory::create( x.ordering(), {(int)x.lengthOf(), 1}, x.getContext()); // column-vector + auto wT = NDArrayFactory::create(x.ordering(), {1, (int)x.lengthOf()}, x.getContext()); // row-vector (transposed w) + + w = x / (float)10.; + w.p(0, 1.); + wT.assign(&w); + + auto result = mmul(w ,wT); + + ASSERT_TRUE(exp.isSameShape(&result)); + ASSERT_TRUE(exp.equalsTo(&result)); + +} + + +TEST_F(NDArrayTest2, Test_Streamline_1) { + auto x = NDArrayFactory::create('c', {3, 4, 6}); + auto y = NDArrayFactory::create('c', {3, 4, 6}); + x.linspace(1); + y.linspace(1); + + x.permutei({1, 0, 2}); + y.permutei({1, 0, 2}); + + y.streamline(); + + ASSERT_TRUE(x.isSameShape(&y)); + ASSERT_TRUE(x.equalsTo(&y)); + ASSERT_FALSE(x.isSameShapeStrict(y)); +} + + +TEST_F(NDArrayTest2, Test_Streamline_2) { + auto x = NDArrayFactory::create('c', {3, 4, 6}); + auto y = NDArrayFactory::create('f', {3, 4, 6}); + x.linspace(1); + y.linspace(1); + + ASSERT_TRUE(x.isSameShape(&y)); + ASSERT_TRUE(x.equalsTo(&y)); + + y.streamline('c'); + + ASSERT_TRUE(x.isSameShape(&y)); + ASSERT_TRUE(x.equalsTo(&y)); +} + +TEST_F(NDArrayTest2, Test_Enforce_1) { + auto x = NDArrayFactory::create('c', {4, 1, 1, 4}); + auto exp = NDArrayFactory::create('c', {4, 4}); + + x.linspace(1); + exp.linspace(1); + + x.enforce({4, 4}, 'c'); + + ASSERT_TRUE(exp.isSameShapeStrict(x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(NDArrayTest2, TestVector_1) { + auto x = NDArrayFactory::create('c', {2, 3}); + auto row = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 1, 2, 3}); + + x.addiRowVector(row); + + ASSERT_TRUE(exp.equalsTo(&x)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Operator_Plus_Test_5) +{ + + auto x = 
NDArrayFactory::create('c', {8, 8, 8}); + auto y = NDArrayFactory::create('c', {8, 1, 8}); + auto expected = NDArrayFactory::create('c', {8, 8, 8}); + + x = 1.; + y = 2.; + expected = 3.; + + auto result = x + y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Operator_Plus_Test_6) { + + auto x = NDArrayFactory::create('c', {3, 3, 3}); + auto y = NDArrayFactory::create('c', {3, 1, 3}); + auto expected = NDArrayFactory::create('c', {3, 3, 3}, {2., 4., 6., 5., 7., 9., 8.,10.,12., 14.,16.,18.,17.,19.,21.,20.,22.,24., 26.,28.,30.,29.,31.,33.,32.,34.,36.}); + x.linspace(1); + y.linspace(1); + + auto result = x + y; + + ASSERT_TRUE(expected.isSameShape(&result)); + ASSERT_TRUE(expected.equalsTo(&result)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, tileToShape_test1) { + + auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); + + x.tileToShape({2,2,2}, x); + + ASSERT_TRUE(x.isSameShape(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, tileToShape_test2) { + + auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4}); + auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); + + x.tileToShape({2,3,2}, x); + + ASSERT_TRUE(x.isSameShape(&exp)); + ASSERT_TRUE(x.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, tileToShape_test3) { + + auto x = NDArrayFactory::create('c', {2, 2}, {1,2,3,4}); + auto result = NDArrayFactory::create('c', {2, 2, 2}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1,2,3,4,1,2,3,4}); + + x.tileToShape({2,2,2}, result); + // result.printIndexedBuffer(); + + ASSERT_TRUE(result.isSameShape(&exp)); + 
ASSERT_TRUE(result.equalsTo(&exp)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, tileToShape_test4) { + + auto x = NDArrayFactory::create('c', {2, 1, 2}, {1,2,3,4}); + auto result = NDArrayFactory::create('c', {2, 3, 2}); + auto exp = NDArrayFactory::create('c', {2, 3, 2}, {1,2,1,2,1,2,3,4,3,4,3,4}); + + x.tileToShape({2,3,2}, result); + + ASSERT_TRUE(result.isSameShape(&exp)); + ASSERT_TRUE(result.equalsTo(&exp)); +} + +#ifndef __CUDABLAS__ + +TEST_F(NDArrayTest2, Test_TriplewiseLambda_1) { + auto t = NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); + auto u = NDArrayFactory::create('c', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto v = NDArrayFactory::create('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); + auto exp = NDArrayFactory::create('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); + + float extra = 1.0f; + + auto la = LAMBDA_DDD(_t, _u, _v, extra) { + return _t + _u + _v + extra; + }; + + t.applyTriplewiseLambda(u, v, la, t); + + ASSERT_TRUE(t.equalsTo(&exp)); +} + + +TEST_F(NDArrayTest2, Test_TriplewiseLambda_2) { + auto t = NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1}); + auto u = NDArrayFactory::create('f', {3, 3}, {2, 2, 2, 2, 2, 2, 2, 2, 2}); + auto v = NDArrayFactory::create('c', {3, 3}, {3, 3, 3, 3, 3, 3, 3, 3, 3}); + auto exp = NDArrayFactory::create('c', {3, 3}, {7, 7, 7, 7, 7, 7, 7, 7, 7}); + + float extra = 1.0f; + + auto la = LAMBDA_DDD(_t, _u, _v, extra) { + return _t + _u + _v + extra; + }; + + t.applyTriplewiseLambda(u, v, la, t); + + ASSERT_TRUE(t.equalsTo(&exp)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_Indexed_Lambda) { + auto x = NDArrayFactory::create('c', {2, 2}); + auto exp = NDArrayFactory::create('c', {2, 2}, {0, 1, 2, 3}); + + auto lambda = ILAMBDA_D(_x) { + return (float) _idx; + }; + + x.applyIndexedLambda(lambda, x); + + ASSERT_TRUE(exp.equalsTo(&x)); +} + +#endif + +TEST_F(NDArrayTest2, 
Test_PermuteEquality_1) { + auto x = NDArrayFactory::create('c', {1, 60}); + auto exp = NDArrayFactory::create('c', {3, 5, 4}, {1.0, 6.0, 11.0, 16.0, 2.0, 7.0, 12.0, 17.0, 3.0, 8.0, 13.0, 18.0, 4.0, 9.0, 14.0, 19.0, 5.0, 10.0, 15.0, 20.0, 21.0, 26.0, 31.0, 36.0, 22.0, 27.0, 32.0, 37.0, 23.0, 28.0, 33.0, 38.0, 24.0, 29.0, 34.0, 39.0, 25.0, 30.0, 35.0, 40.0, 41.0, 46.0, 51.0, 56.0, 42.0, 47.0, 52.0, 57.0, 43.0, 48.0, 53.0, 58.0, 44.0, 49.0, 54.0, 59.0, 45.0, 50.0, 55.0, 60.0}); + x.linspace(1); + x.reshapei('c', {3, 4, 5}); + + x.permutei({0, 2, 1}); + x.streamline(); + +// x.printShapeInfo("{0, 2, 1} shape"); +// x.printBuffer("{0, 2, 1} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(NDArrayTest2, Test_PermuteEquality_0) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {3, 4, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({0, 1, 2}); + x.streamline(); + +// x.printShapeInfo("{0, 1, 2} shape"); +// x.printBuffer("{0, 1, 2} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + + +TEST_F(NDArrayTest2, Test_PermuteEquality_2) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {4, 3, 5}, {1.0, 2.0, 3.0, 4.0, 5.0, 21.0, 22.0, 23.0, 24.0, 25.0, 41.0, 42.0, 43.0, 44.0, 45.0, 6.0, 7.0, 8.0, 9.0, 10.0, 26.0, 27.0, 28.0, 29.0, 30.0, 46.0, 47.0, 48.0, 49.0, 50.0, 11.0, 12.0, 13.0, 14.0, 15.0, 31.0, 32.0, 33.0, 34.0, 35.0, 51.0, 52.0, 53.0, 54.0, 55.0, 16.0, 17.0, 18.0, 19.0, 20.0, 36.0, 37.0, 38.0, 39.0, 40.0, 56.0, 57.0, 58.0, 59.0, 
60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({1, 0, 2}); + x.streamline(); + +// x.printShapeInfo("{1, 0, 2} shape"); +// x.printBuffer("{1, 0, 2} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(NDArrayTest2, Test_PermuteEquality_3) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {4, 5, 3}, {1.0, 21.0, 41.0, 2.0, 22.0, 42.0, 3.0, 23.0, 43.0, 4.0, 24.0, 44.0, 5.0, 25.0, 45.0, 6.0, 26.0, 46.0, 7.0, 27.0, 47.0, 8.0, 28.0, 48.0, 9.0, 29.0, 49.0, 10.0, 30.0, 50.0, 11.0, 31.0, 51.0, 12.0, 32.0, 52.0, 13.0, 33.0, 53.0, 14.0, 34.0, 54.0, 15.0, 35.0, 55.0, 16.0, 36.0, 56.0, 17.0, 37.0, 57.0, 18.0, 38.0, 58.0, 19.0, 39.0, 59.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({1, 2, 0}); + x.streamline(); + +// x.printShapeInfo("{1, 2, 0} shape"); +// x.printBuffer("{1, 2, 0} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(NDArrayTest2, Test_PermuteEquality_4) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {5, 3, 4}, {1.0, 6.0, 11.0, 16.0, 21.0, 26.0, 31.0, 36.0, 41.0, 46.0, 51.0, 56.0, 2.0, 7.0, 12.0, 17.0, 22.0, 27.0, 32.0, 37.0, 42.0, 47.0, 52.0, 57.0, 3.0, 8.0, 13.0, 18.0, 23.0, 28.0, 33.0, 38.0, 43.0, 48.0, 53.0, 58.0, 4.0, 9.0, 14.0, 19.0, 24.0, 29.0, 34.0, 39.0, 44.0, 49.0, 54.0, 59.0, 5.0, 10.0, 15.0, 20.0, 25.0, 30.0, 35.0, 40.0, 45.0, 50.0, 55.0, 60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({2, 0, 1}); + x.streamline(); + +// x.printShapeInfo("{2, 0, 1} shape"); +// x.printBuffer("{2, 0, 1} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(NDArrayTest2, Test_PermuteEquality_5) { + auto x = NDArrayFactory::create('c', {1, 60}); + x.linspace(1); + auto exp = NDArrayFactory::create('c', {5, 4, 3}, + {1.0, 21.0, 41.0, 6.0, 26.0, 46.0, 11.0, 31.0, 51.0, 16.0, 36.0, 56.0, 2.0, 22.0, 42.0, 
7.0, + 27.0, 47.0, 12.0, 32.0, 52.0, 17.0, 37.0, 57.0, 3.0, 23.0, 43.0, 8.0, 28.0, 48.0, 13.0, 33.0, + 53.0, 18.0, 38.0, 58.0, 4.0, 24.0, 44.0, 9.0, 29.0, 49.0, 14.0, 34.0, 54.0, 19.0, 39.0, 59.0, + 5.0, 25.0, 45.0, 10.0, 30.0, 50.0, 15.0, 35.0, 55.0, 20.0, 40.0, 60.0}); + x.reshapei('c', {3, 4, 5}); + + x.permutei({2, 1, 0}); + x.streamline(); + +// x.printShapeInfo("{2, 0, 1} shape"); +// x.printBuffer("{2, 0, 1} data"); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, fillAsTriangular_test1) { + + auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1,0,0,0,5,6,0,0,9,10,11,0 ,13,14,15,16}); + + x.fillAsTriangular(0., 0, 0, x, 'u'); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, fillAsTriangular_test2) { + + auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); + auto exp = NDArrayFactory::create('c', {4, 4}, {0,0,0,0,5,0,0,0,9,10,0 ,0 ,13,14,15,0}); + + x.fillAsTriangular(0., 0, -1, x, 'u'); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, fillAsTriangular_test3) { + + auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); + auto exp = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,0,6,7,8,0,0 ,11,12,0 ,0 , 0,16}); + + x.fillAsTriangular(0., 0, 0, x, 'l'); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, fillAsTriangular_test4) { + + auto x = NDArrayFactory::create('c', {4, 4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16}); + auto exp = 
NDArrayFactory::create('c', {4, 4}, {0,2,3,4,0,0,7,8,0,0 , 0,12, 0, 0, 0, 0}); + + x.fillAsTriangular(0., 1, 0, x, 'l'); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_DType_Conversion_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + + auto xd = x.template asT(); + + auto xf = xd.template asT(); + + ASSERT_TRUE(x.isSameShape(xf)); + ASSERT_TRUE(x.equalsTo(xf)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_ScalarArray_Assign_1) { + auto x = NDArrayFactory::create('c', {2, 2}); + auto y = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {2, 2}, {2.0f, 2.0f, 2.0f, 2.0f}); + + x.assign(y); + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_Reshape_To_Vector_1) { + auto x = NDArrayFactory::create('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto exp = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + + x.reshapei({-1}); + + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); +} + + +TEST_F(NDArrayTest2, Test_toIndexedString_1) { + auto x = NDArrayFactory::create('c', {2, 2}, {1.5f, 2.5f, 3.f, 4.5f}); + + auto str = x.asIndexedString(); + std::string exp = "[1.5, 2.5, 3, 4.5]"; + + ASSERT_EQ(exp, str); +} + + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, permute_test4) { + + Nd4jLong arr1ShapeInfo[] = {6, 1, 1, 4, 3, 2, 2, 48, 48, 12, 4, 2, 1, 8192, 1, 99}; + Nd4jLong arr2ShapeInfo[] = {6, 1, 2, 2, 1, 4, 3, 48, 2, 1, 48, 12, 4, 8192, 0, 99}; + + + auto arr1Buffer = new float[786432]; + auto arr2Buffer = new float[786432]; + + NDArray arr1(arr1Buffer, arr1ShapeInfo, sd::LaunchContext ::defaultContext()); + NDArray arr2(arr2Buffer, arr2ShapeInfo, 
sd::LaunchContext ::defaultContext()); + + const std::vector perm = {0, 4, 5, 1, 2, 3}; + auto arr1P = arr1.permute(perm); + // arr1P->printShapeInfo(); + + // ASSERT_TRUE(arr1.isSameShapeStrict(&arr2)); + ASSERT_TRUE(arr1P.isSameShapeStrict(arr2)); + delete []arr1Buffer; + delete []arr2Buffer; +} + +//////////////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, TestStdDev3) { + + // autoarray('c', {10, 10}); + auto array = NDArrayFactory::create('c', {2, 2}, {0.2946, 0.2084, 0.0345, 0.7368}); + const int len = array.lengthOf(); + + double sum = 0.; + for(int i=0; i < len; ++i) + sum += array.e(i); + + const double mean = sum / len; + + double diffSquared = 0.; + for(int i=0; i < len; ++i) + diffSquared += (array.e(i) - mean) * (array.e(i) - mean); + + const double trueVariance = math::nd4j_sqrt(diffSquared / len); + const double trueVarianceCorr = math::nd4j_sqrt(diffSquared / (len - 1)); + + const double variance = array.varianceNumber(variance::SummaryStatsStandardDeviation, false).e(0); + const double varianceCorr = array.varianceNumber(variance::SummaryStatsStandardDeviation, true).e(0); + + // printf("%s expected %.10f calculated %.10f\n","variance :", trueVariance, variance ); + // printf("%s expected %.10f calculated %.10f\n","variance corrected:", trueVarianceCorr, varianceCorr); + + ASSERT_NEAR(trueVariance, variance, 1e-8); + ASSERT_NEAR(trueVarianceCorr, varianceCorr, 1e-8); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_Linspace_1) { + auto exp = NDArrayFactory::create('c',{1,5}, {1., 2., 3., 4., 5.}); + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(1); + + ASSERT_TRUE(x.equalsTo(&exp)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_Linspace_2) { + auto exp = NDArrayFactory::create('c',{1,5}, {1., 3., 5., 7., 9.}); + auto x = NDArrayFactory::create('c', {1, 5}); + + x.linspace(1, 2); + + 
ASSERT_TRUE(x.equalsTo(&exp)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_Linspace_3) { + + auto exp = NDArrayFactory::create('c',{1,5}, {1., 4., 7., 10., 13.}); + + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(1,3); + + ASSERT_TRUE(x.equalsTo(&exp)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_Linspace_4) { + auto exp = NDArrayFactory::create('c',{1,5}, {-1., -2., -3., -4., -5.}); + + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(-1, -1); + + ASSERT_TRUE(x.equalsTo(&exp)); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, Test_Linspace_5) { + auto exp = NDArrayFactory::create('c',{1,5}, {9., 8., 7., 6., 5.}); + + auto x = NDArrayFactory::create('c', {1, 5}); + x.linspace(9, -1); + + ASSERT_TRUE(x.equalsTo(&exp)); +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, allTensorsAlongDimension_test1) { + + auto x = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + auto exp = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + + auto set = x.allTensorsAlongDimension({0}); + // set->at(0)->printShapeInfo(); + // set->at(0)->printIndexedBuffer(); + + ASSERT_TRUE(set.size() == 1); + ASSERT_TRUE(exp.isSameShape(set.at(0))); + ASSERT_TRUE(exp.equalsTo(set.at(0))); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, scalar_get_test1) { + + auto scalar1 = NDArrayFactory::create(20.f); + + NDArray arr('c', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + + NDArray scalar2 = arr.e(2); + + ASSERT_TRUE(scalar1.isSameShape(scalar2)); + ASSERT_TRUE(scalar1.equalsTo(scalar2)); + ASSERT_TRUE(scalar1.dataType() == scalar2.dataType()); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, scalar_get_test2) { + + auto scalar1 = NDArrayFactory::create(20.f); + + 
NDArray arr('f', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + + NDArray scalar2 = arr.e(1); + + ASSERT_TRUE(scalar1.isSameShape(scalar2)); + ASSERT_TRUE(scalar1.equalsTo(scalar2)); + ASSERT_TRUE(scalar1.dataType() == scalar2.dataType()); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, scalar_set_test1) { + + NDArray scalar1 = NDArrayFactory::create(20.f); + + NDArray arr('c', {2,2}, {0., 10., -20., 30.}, sd::DataType::FLOAT32); + NDArray exp('c', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + + arr.p(2, scalar1); + + ASSERT_TRUE(exp.equalsTo(arr)); +} + + +//////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, scalar_set_test2) { + + NDArray scalar1 = NDArrayFactory::create(20.f); + + NDArray arr('f', {2,2}, {0., 10., -20., 30.}, sd::DataType::FLOAT32); + NDArray exp('f', {2,2}, {0., 10., 20., 30.}, sd::DataType::FLOAT32); + + arr.p(1, scalar1); + + ASSERT_TRUE(exp.equalsTo(arr)); +} + +TEST_F(NDArrayTest2, big_dup_test) { + // auto arr = NDArrayFactory::linspace(1.0f, 10000000.0f, 100000000); + auto arr = NDArrayFactory::linspace(1.0f, 1000.0f, 10000); + auto dup = new NDArray(arr->dup('c')); + + ASSERT_EQ(*arr, *dup); + + delete arr; + delete dup; +} + +TEST_F(NDArrayTest2, debugInfoTest_1) { + NDArray testArray('c', {2, 4, 4, 4}, { + 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., + 91., -82., 37., 64., -55.1, 0, 73., 28., -119., 12., 112., 13., 14., 114., 16.2, 117., + 91., -82., 37., 64., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., + 51., 42., 67., 24., 15., 0., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 91., 82., 37., 64., -3, 0, 73., 
28., 119., 12., 112., 13., 140., 110., 160., 107.}, sd::DataType::DOUBLE); + NDArray res(sd::DataType::DOUBLE); + DebugInfo info = DebugHelper::debugStatistics(&testArray); + DebugInfo exp; // = {} + sd::ops::reduce_min minOp; + sd::ops::reduce_mean meanOp; + sd::ops::reduce_max maxOp; + sd::ops::reduce_stdev stdevOp; + + minOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._minValue = res.e(0); + meanOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._meanValue = res.e(0); + maxOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._maxValue = res.e(0); + stdevOp.execute({&testArray}, {&res}, {}, {}, {}); + exp._stdDevValue = res.e(0); + exp._zeroCount = 3; + exp._negativeCount = 7; + exp._positiveCount = 118; + exp._infCount = 0; + exp._nanCount = 0; + printf("Output statistics %lf %lf %lf %lf\n", info._minValue, info._maxValue, info._meanValue, info._stdDevValue); + printf("Expect statistics %lf %lf %lf %lf\n", exp._minValue, exp._maxValue, exp._meanValue, exp._stdDevValue); + printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, info._nanCount); + ASSERT_EQ(exp, info); +} + +TEST_F(NDArrayTest2, debugInfoTest_2) { + NDArray testArray('c', {2, 4, 4, 4}, { + 91., 82., 37., 64., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 51., 42., 67., 24., 15., 56., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., + 91., -82., 37., 64., -55.1, 0, 73., 28., -119., 12., 112., 13., 14., 114., 16.2, 117., + 91., -82., 37., 64., 55., 46., 73., 28., -119., 12., 112., 13., 14., 114., 16., 117., + 51., 42., 67., 24., 15., 0., 93., 28., 109., 82., 12., 113., 114., 14., 116., 11., + 31., 22., 87., 44., 55., 46., 73., 28., 119., 12., 112., 13., 14., 114., 16., 117., + 91., 82., 37., 64., -3, 0, 73., 28., 119., 12., 112., 13., 140., 110., 160., 107.}, sd::DataType::DOUBLE); + + DebugInfo info; + DebugInfo exp; // = {} + 
exp._minValue = -119; + exp._maxValue = 160.; + exp._meanValue = 51.328906; + exp._stdDevValue = 52.385694; + exp._zeroCount = 3; + exp._negativeCount = 7; + exp._positiveCount = 118; + exp._infCount = 0; + exp._nanCount = 0; + DebugHelper::retrieveDebugStatistics(&info, &testArray); + printf("Output statistics %lf %lf %lf %lf\n", info._minValue, info._maxValue, info._meanValue, info._stdDevValue); + printf("Expect statistics %lf %lf %lf %lf\n", exp._minValue, exp._maxValue, exp._meanValue, exp._stdDevValue); + printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, info._nanCount); + //printf("%lf %lf %lf %lf\n", info._minValue, info._maxValue, info._meanValue, info._stdDevValue); + //printf("%lld %lld %lld %lld %lld\n", info._zeroCount, info._negativeCount, info._positiveCount, info._infCount, info._nanCount); + ASSERT_EQ(exp, info); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, test_subarray_ews_1) { + + NDArray x('c', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); + + ASSERT_EQ(5, subArr1.ews()); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, test_subarray_ews_2) { + + NDArray x('f', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::point(2)}); + + ASSERT_EQ(1, subArr1.ews()); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, test_subarray_ews_3) { + + NDArray x('c', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); + + ASSERT_EQ(1, subArr1.ews()); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, test_subarray_ews_4) { + + NDArray x('f', {10, 5}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::point(2), NDIndex::all()}); + + ASSERT_EQ(10, subArr1.ews()); +} 
+ +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, subarray_1) { + + NDArray x('c', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); + NDArray y('f', {2,3,4}, {1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24}, sd::DataType::FLOAT32); + + Nd4jLong shapeExpX0[] = {1, 2, 12, 8192, 12, 99}; + float buffExpX0[] = {1.000000, 13.000000}; + float buffExpX1[] = {2.000000, 14.000000}; + Nd4jLong shapeExpX2[] = {3, 2, 1, 1, 12, 4, 1, 8192, 12, 99}; + float buffExpX2[] = {1.000000, 13.000000}; + Nd4jLong shapeExpX3[] = {2, 2, 4, 12, 1, 8192, 0, 99}; + float buffExpX3[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000}; + Nd4jLong shapeExpX4[] = {3, 2, 1, 4, 12, 4, 1, 8192, 0, 99}; + float buffExpX4[] = {9.000000, 10.000000, 11.000000, 12.000000, 21.000000, 22.000000, 23.000000, 24.000000}; + Nd4jLong shapeExpX5[] = {2, 2, 3, 12, 4, 8192, 4, 99}; + float buffExpX5[] = {4.000000, 8.000000, 12.000000, 16.000000, 20.000000, 24.000000}; + + Nd4jLong shapeExpY0[] = {1, 2, 1, 8192, 1, 102}; + float buffExpY0[] = {1.000000, 2.000000}; + float buffExpY1[] = {7.000000, 8.000000}; + Nd4jLong shapeExpY2[] = {3, 2, 1, 1, 1, 2, 6, 8192, 1, 102}; + float buffExpY2[] = {1.000000, 2.000000}; + Nd4jLong shapeExpY3[] = {2, 2, 4, 1, 6, 8192, 0, 102}; + float buffExpY3[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000}; + Nd4jLong shapeExpY4[] = {3, 2, 1, 4, 1, 2, 6, 8192, 0, 102}; + float buffExpY4[] = {5.000000, 11.000000, 17.000000, 23.000000, 6.000000, 12.000000, 18.000000, 24.000000}; + Nd4jLong shapeExpY5[] = {2, 2, 3, 1, 2, 8192, 1, 102}; + float buffExpY5[] = {19.000000, 21.000000, 23.000000, 20.000000, 22.000000, 24.000000}; + + + NDArray x0 = x(0, {1,2}); + for(int i = 0; i < shape::shapeInfoLength(x0.rankOf()); ++i) + ASSERT_TRUE(x0.shapeInfo()[i] == shapeExpX0[i]); + for(int i = 0; i < 
x0.lengthOf(); ++i) + ASSERT_TRUE(x0.e(i) == buffExpX0[i]); + + NDArray x1 = x(1, {1,2}); + for(int i = 0; i < shape::shapeInfoLength(x1.rankOf()); ++i) + ASSERT_TRUE(x1.shapeInfo()[i] == shapeExpX0[i]); + for(int i = 0; i < x1.lengthOf(); ++i) + ASSERT_TRUE(x1.e(i) == buffExpX1[i]); + + NDArray x2 = x(0, {1,2}, true); + for(int i = 0; i < shape::shapeInfoLength(x2.rankOf()); ++i) + ASSERT_TRUE(x2.shapeInfo()[i] == shapeExpX2[i]); + for(int i = 0; i < x2.lengthOf(); ++i) + ASSERT_TRUE(x2.e(i) == buffExpX2[i]); + + NDArray x3 = x(2, {1}); + for(int i = 0; i < shape::shapeInfoLength(x3.rankOf()); ++i) + ASSERT_TRUE(x3.shapeInfo()[i] == shapeExpX3[i]); + for(int i = 0; i < x3.lengthOf(); ++i) + ASSERT_TRUE(x3.e(i) == buffExpX3[i]); + + NDArray x4 = x(2, {1}, true); + for(int i = 0; i < shape::shapeInfoLength(x4.rankOf()); ++i) + ASSERT_TRUE(x4.shapeInfo()[i] == shapeExpX4[i]); + for(int i = 0; i < x4.lengthOf(); ++i) + ASSERT_TRUE(x4.e(i) == buffExpX4[i]); + + NDArray x5 = x(3, {2}); + for(int i = 0; i < shape::shapeInfoLength(x5.rankOf()); ++i) + ASSERT_TRUE(x5.shapeInfo()[i] == shapeExpX5[i]); + for(int i = 0; i < x5.lengthOf(); ++i) + ASSERT_TRUE(x5.e(i) == buffExpX5[i]); + + // ******************* // + NDArray y0 = y(0, {1,2}); + for(int i = 0; i < shape::shapeInfoLength(y0.rankOf()); ++i) + ASSERT_TRUE(y0.shapeInfo()[i] == shapeExpY0[i]); + for(int i = 0; i < y0.lengthOf(); ++i) + ASSERT_TRUE(y0.e(i) == buffExpY0[i]); + + NDArray y1 = y(1, {1,2}); + for(int i = 0; i < shape::shapeInfoLength(y1.rankOf()); ++i) + ASSERT_TRUE(y1.shapeInfo()[i] == shapeExpY0[i]); + for(int i = 0; i < y1.lengthOf(); ++i) + ASSERT_TRUE(y1.e(i) == buffExpY1[i]); + + NDArray y2 = y(0, {1,2}, true); + for(int i = 0; i < shape::shapeInfoLength(y2.rankOf()); ++i) + ASSERT_TRUE(y2.shapeInfo()[i] == shapeExpY2[i]); + for(int i = 0; i < y2.lengthOf(); ++i) + ASSERT_TRUE(y2.e(i) == buffExpY2[i]); + + NDArray y3 = y(2, {1}); + for(int i = 0; i < shape::shapeInfoLength(y3.rankOf()); ++i) + 
ASSERT_TRUE(y3.shapeInfo()[i] == shapeExpY3[i]); + for(int i = 0; i < y3.lengthOf(); ++i) + ASSERT_TRUE(y3.e(i) == buffExpY3[i]); + + NDArray y4 = y(2, {1}, true); + for(int i = 0; i < shape::shapeInfoLength(y4.rankOf()); ++i) + ASSERT_TRUE(y4.shapeInfo()[i] == shapeExpY4[i]); + for(int i = 0; i < y4.lengthOf(); ++i) + ASSERT_TRUE(y4.e(i) == buffExpY4[i]); + + NDArray y5 = y(3, {2}); + for(int i = 0; i < shape::shapeInfoLength(y5.rankOf()); ++i) + ASSERT_TRUE(y5.shapeInfo()[i] == shapeExpY5[i]); + for(int i = 0; i < y5.lengthOf(); ++i) + ASSERT_TRUE(y5.e(i) == buffExpY5[i]); + +} + +TEST_F(NDArrayTest2, test_subarray_interval_1) { + + NDArray x('f', {10, 10}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0,9)}); + + ASSERT_EQ(10, subArr1.sizeAt(0)); + ASSERT_EQ(9, subArr1.sizeAt(1)); +} + +TEST_F(NDArrayTest2, test_subarray_interval_2) { + + NDArray x('c', {10, 10}, sd::DataType::FLOAT32); + auto subArr1 = x.subarray({NDIndex::all(), NDIndex::interval(0,9)}); + + ASSERT_EQ(10, subArr1.sizeAt(0)); + ASSERT_EQ(9, subArr1.sizeAt(1)); +} + +TEST_F(NDArrayTest2, test_subarray_3d_cf) { + NDArray f('f', {10, 20, 30}, sd::DataType::FLOAT32); + NDArray c('c', {10, 20, 30}, sd::DataType::FLOAT32); + + auto subarrayF = f({0,0, 0,0, 2,3}, true); + + auto subarrayC = c({2,3, 0,0, 0,0}, true); +} + +TEST_F(NDArrayTest2, test_broadcast_row_1) { + auto x = NDArrayFactory::create('c', {10, 5}); + auto y = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {10, 5}); + e.assign(1.0f); + + x += y; + + ASSERT_EQ(e, x); +} + +TEST_F(NDArrayTest2, test_broadcast_column_1) { + auto x = NDArrayFactory::create('c', {5, 10}); + auto y = NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0f); + + x += y; + + ASSERT_EQ(e, x); +} + +TEST_F(NDArrayTest2, test_broadcast_column_2) { + auto x = NDArrayFactory::create('c', {5, 10}); + 
auto y = NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x, false); + + ASSERT_EQ(e, x); +} + +TEST_F(NDArrayTest2, test_broadcast_column_3) { + auto x = NDArrayFactory::create('c', {5, 10}); + auto y = NDArrayFactory::create('c', {5, 1}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('c', {5, 10}); + e.assign(1.0f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); + + ASSERT_EQ(e, x); +} + +TEST_F(NDArrayTest2, test_broadcast_column_4) { + auto x = NDArrayFactory::create('f', {10, 5}); + auto y = NDArrayFactory::create('f', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + auto e = NDArrayFactory::create('f', {10, 5}); + e.assign(1.0f); + + x.applyTrueBroadcast(BroadcastOpsTuple::Add(), y, x); + + ASSERT_EQ(e, x); +} + +TEST_F(NDArrayTest2, test_not_tiled_1) { + auto x = NDArrayFactory::create('c', {4, 12, 128, 128}); + auto y = NDArrayFactory::create('c', {4, 1, 128, 128}); + auto e = NDArrayFactory::create('c', {4, 12, 128, 128}); + y.assign(1.0f); + e.assign(1.0f); + + x += y; + + ASSERT_EQ(e, x); +} + +TEST_F(NDArrayTest2, test_not_tiled_2) { + auto x = NDArrayFactory::create('c', {4, 128, 768}); + auto y = NDArrayFactory::create('c', {4, 128, 1}); + auto e = NDArrayFactory::create('c', {4, 128, 768}); + y.assign(1.0f); + e.assign(1.0f); + + x += y; + + ASSERT_EQ(e, x); +} + +TEST_F(NDArrayTest2, test_long_sum_1) { + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + + auto z = x.reduceAlongDimension(reduce::Sum, {0}); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, reshapei_1) { + + Nd4jLong shapeInfo1[] = {6, 2,1,2,1,7,1, 7,7,14,28,1,1, 8192, 0, 99}; + Nd4jLong shapeInfo2[] = {2, 4, 7, 7, 1, 8192, 1, 99}; + + auto buffer = new float[shape::length(shapeInfo1)]; + NDArray x(buffer, shapeInfo1); + + const bool canReshape = x.reshapei({4,7}); + + 
ASSERT_FALSE(canReshape); + ASSERT_TRUE(shape::equalsStrict(x.shapeInfo(), shapeInfo2)); + + delete[] buffer; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, reshapei_2) { + + Nd4jLong shapeInfo1[] = {6, 1,2,1,2,7,1, 28,7,7,14,1,1, 8192, 0, 99}; + Nd4jLong shapeInfo2[] = {2, 4, 7, 7, 1, 8192, 1, 99}; + + auto buffer = new float[shape::length(shapeInfo1)]; + NDArray x(buffer, shapeInfo1); + + const bool canReshape = x.reshapei({4,7}); + + ASSERT_FALSE(canReshape); + ASSERT_TRUE(shape::equalsStrict(x.shapeInfo(), shapeInfo2)); + + delete[] buffer; +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, trueBroadcast_1) { + + NDArray x('f', {2, 3}, {1., 2., 3., 4., 5., 6.}); + NDArray y('f', {1, 3}, {5., 4., 3.}); + NDArray z('c', {2, 3}, sd::DataType::DOUBLE); + + auto exp = x - y; + x.applyTrueBroadcast(sd::BroadcastOpsTuple::Subtract(), y, z); + + // exp.printIndexedBuffer(); + // z.printIndexedBuffer(); + + ASSERT_TRUE(exp.equalsTo(z)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, reduce_1) { + + NDArray arr6('f', {1, 1, 4, 4, 4, 4}, sd::DataType::DOUBLE); + NDArray exp('f', {1, 1, 4, 4}, sd::DataType::DOUBLE); + + arr6.linspace(1); + + NDArray arr6s = arr6.reduceAlongDimension(sd::reduce::Sum, {2,3}); + + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + double sum = 0; + for (int x = 0; x < 4; x++) { + for (int y = 0; y < 4; y++) { + Nd4jLong indices[] = {0, 0, x, y, i, j}; + Nd4jLong offset = shape::getOffset(arr6.shapeInfo(), indices); + sum += ((double*)arr6.buffer())[offset]; + } + } + exp.p(0, 0, i, j, sum); + } + } + + // arr6s->printShapeInfo(); + // exp.printShapeInfo(); + // exp.printIndexedBuffer(); + // arr6s->printIndexedBuffer(); + + ASSERT_TRUE(exp.equalsTo(arr6s)); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(NDArrayTest2, reduce3_1) { + + NDArray 
x('c', {1,4}, {1,2,3,4}); + NDArray y('c', {1,4}, {2,3,4,5}); + NDArray exp('c', {4}, {1,1,1,1}); + + NDArray z = x.applyReduce3(sd::reduce3::EuclideanDistance, y, {0}, nullptr); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NDArrayTest2, all_tads_1) { + auto x = NDArrayFactory::create('c', {3, 5}); + + auto arrays = x.allTensorsAlongDimension({1}); + ASSERT_EQ(3, arrays.size()); +} + +TEST_F(NDArrayTest2, test_trueBroadcast_empty_1) { + auto x = NDArrayFactory::create('c', {0, 2}); + auto y = NDArrayFactory::create('c', {1, 2}); + + auto z = x + y; + + ASSERT_EQ(x, z); +} + +TEST_F(NDArrayTest2, test_trueBroadcast_empty_2) { + auto x = NDArrayFactory::create('c', {0, 2}); + auto y = NDArrayFactory::create('c', {1, 2}); + + auto z = y + x; + + ASSERT_EQ(x, z); +} + +TEST_F(NDArrayTest2, test_subarray_followed_by_reshape_1) { + + NDArray x('c', {5, 1, 3}, sd::DataType::FLOAT32); + NDArray e('c', {1, 3}, {7.f, 8.f, 9.f}, sd::DataType::FLOAT32); + + x.linspace(1.); + + auto s = x({2,3, 0,0, 0,0}); + + // s.printIndexedBuffer("s"); + + auto r = s.reshape(x.ordering(), {1, 3}); + // r.printIndexedBuffer("r"); + + ASSERT_EQ(e, r); +} + +TEST_F(NDArrayTest2, test_numpy_import_1) { + std::string fname("./resources/arr_3,4_float32.npy"); + auto exp = NDArrayFactory::create('c', {3, 4}); + exp.linspace(0); + + auto array = NDArrayFactory::fromNpyFile(fname.c_str()); + + ASSERT_EQ(exp, array); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NativeOpsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NativeOpsTests.cpp new file mode 100644 index 000000000..29c56c727 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NativeOpsTests.cpp @@ -0,0 +1,1612 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of 
the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by GS on 22.07.2019. +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace sd; +using namespace sd::ops; + +class NativeOpsTests : public testing::Test { +public: + +}; + + +TEST_F(NativeOpsTests, CreateContextTests_1) { +// auto x = NDArrayFactory::create('c', {5, 5}); +// x.assign(1.0); +// auto z = NDArrayFactory::create('c', {5,5}); +// auto exp = NDArrayFactory::create('c', {5, 5}); + auto context = ::createContext(); + ASSERT_TRUE(context == nullptr); + //delete context; +} + +TEST_F(NativeOpsTests, CreateContextTests_2) { +// auto x = NDArrayFactory::create('c', {5, 5}); +// x.assign(1.0); +// auto z = NDArrayFactory::create('c', {5,5}); +// auto exp = NDArrayFactory::create('c', {5, 5}); + auto context1 = ::createContext(); + auto context2 = ::createContext(); + ASSERT_TRUE(context1 == context2); + //delete context1; + //delete context2; +} + +TEST_F(NativeOpsTests, PointerTests_1) { + auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); +// x.linspace(1.0); +#ifdef __CUDABLAS__ +printf("Unsupported for cuda now.\n"); +#else + ::tryPointer(nullptr, x.buffer(), 4); +#endif + +// auto exp = NDArrayFactory::create('c', {5, 5}); +// exp.assign(-1.0); 
+// +// sd::ops::LegacyTransformSameOp op(transform::Neg); // Neg +// auto result = op.execute({&x}, {}, {}); +// +// ASSERT_EQ(1, result->size()); +// +// auto z = result->at(0); +// +// ASSERT_TRUE(exp.equalsTo(z)); +// +// delete result; +} + +TEST_F(NativeOpsTests, ThresholdTests_1) { +// auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); +// x.linspace(1.0); +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + ::setElementThreshold(4); + ASSERT_TRUE(4 == sd::Environment::getInstance().elementwiseThreshold()); +#endif + +} + +TEST_F(NativeOpsTests, ThresholdTests_2) { +// auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); +// x.linspace(1.0); +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + ::setTADThreshold(4); + ASSERT_TRUE(4 == sd::Environment::getInstance().tadThreshold()); +#endif + +} + +TEST_F(NativeOpsTests, ExecIndexReduce_1) { + auto x = NDArrayFactory::create('c', {5}, {1,2,3,4,5}); + auto exp = NDArrayFactory::create(120); + x.linspace(1.0); +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execIndexReduceScalar(nullptr, + indexreduce::IndexMax, + &xBuf, x.shapeInfo(), + nullptr, + nullptr, + &expBuf, exp.shapeInfo(), + nullptr); + + ASSERT_TRUE(exp.e(0) == 4LL); +#endif + +} + +TEST_F(NativeOpsTests, ExecIndexReduce_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120); + x.linspace(1.0); +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + NDArray dimension = NDArrayFactory::create({}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimensionBuf(dimension.dataBuffer()); + + ::execIndexReduce(nullptr, + indexreduce::IndexMax, + &xBuf, x.shapeInfo(), nullptr, + nullptr, + &expBuf, exp.shapeInfo(), + nullptr, + &dimensionBuf, dimension.shapeInfo(), + nullptr); + + 
ASSERT_TRUE(exp.e(0) == 24LL); +#endif + +} + +TEST_F(NativeOpsTests, ExecBroadcast_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 1}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + y.linspace(2,2); +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + auto dimension = NDArrayFactory::create('c', {1}, {1}); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execBroadcast(nullptr, + broadcast::Add, + &xBuf, x.shapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), + nullptr, + &expBuf, exp.shapeInfo(), + nullptr, + &dimBuf, dimension.shapeInfo(), + nullptr); + + ASSERT_TRUE(exp.e(0) == 3.); +#endif + +} + +TEST_F(NativeOpsTests, ExecBroadcast_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 1}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + y.linspace(2,2); +#ifdef __CUDABLAS__ +printf("Unsupported for cuda now.\n"); +#else + int dimd = 0; + auto dimension = NDArrayFactory::create('c', {1}, {dimd}); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execBroadcastBool(nullptr, + broadcast::EqualTo, + &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), nullptr, nullptr, + &dimBuf, dimension.shapeInfo(), + nullptr); + ASSERT_TRUE(exp.e(1) && !exp.e(0)); +#endif + +} + +TEST_F(NativeOpsTests, ExecPairwise_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.linspace(1.0); + y.assign(2.); +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer 
yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execPairwiseTransform(nullptr, + pairwise::Add, + &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), nullptr, + nullptr); + ASSERT_TRUE(exp.e(5) == 8.); +#endif + +} + +TEST_F(NativeOpsTests, ExecPairwise_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + x.assign(true); + y.assign(false); + y.r(5) = true; +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execPairwiseTransformBool(nullptr, + pairwise::And, + &xBuf, x.shapeInfo(), nullptr, + &yBuf, y.shapeInfo(), nullptr, + &expBuf, exp.shapeInfo(), nullptr, + nullptr); + ASSERT_TRUE(exp.e(5) && !exp.e(4)); +#endif + +} + +TEST_F(NativeOpsTests, ReduceTest_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + x.linspace(1.0); + +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + auto dimension = NDArrayFactory::create('c', {1}, {1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceFloat(nullptr, + reduce::Mean, + &xBuf, x.shapeInfo(), nullptr, + nullptr, + &expBuf, exp.shapeInfo(), nullptr); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce Mean"); + ASSERT_TRUE(exp.e(0) == 13.); +#endif + +} + +TEST_F(NativeOpsTests, ReduceTest_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + x.linspace(1.0); + +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceSame(nullptr, + reduce::Sum, + &xBuf, x.shapeInfo(), nullptr, + nullptr, + &expBuf, exp.shapeInfo(), nullptr); 
+// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce Sum"); + ASSERT_TRUE(exp.e(0) == 325.); +#endif + +} + +TEST_F(NativeOpsTests, ReduceTest_3) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(false); + x.linspace(1.0); + +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceBool(nullptr, + reduce::All, + &xBuf, x.shapeInfo(), nullptr, + nullptr, + &expBuf, exp.shapeInfo(), nullptr); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.e(0) == true); +#endif + +} + +TEST_F(NativeOpsTests, ReduceTest_4) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120LL); + x.linspace(1.0); + +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceLong(nullptr, + reduce::CountNonZero, + &xBuf, x.shapeInfo(), nullptr, + nullptr, + &expBuf, exp.shapeInfo(), nullptr); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce CountNonZero"); + ASSERT_TRUE(exp.e(0) == 25LL); +#endif + +} + +TEST_F(NativeOpsTests, ReduceTest_5) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120LL); + x.linspace(1.0); + +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + auto dimension = NDArrayFactory::create({0, 1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execReduceLong2(nullptr, + reduce::CountNonZero, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce CountNonZero"); + 
ASSERT_TRUE(exp.e(0) == 25LL); +#endif + +} + +TEST_F(NativeOpsTests, ReduceTest_6) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create({5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create({1,2,3,4,6}); + x.linspace(1.0); + +#ifdef __CUDABLAS__ + printf("Unsupported for cuda now.\n"); +#else + auto dimension = NDArrayFactory::create('c', {1}, {1}); + x.p(5, 0); + x.p(10, 0); x.p(11, 0); + x.p(15, 0); x.p(16, 0); x.p(17, 0); + x.p(20, 0); x.p(21, 0); x.p(22, 0); x.p(23, 0); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceLong2(nullptr, + reduce::CountNonZero, + &xBuf, x.shapeInfo(), nullptr, + nullptr, + &expBuf, exp.shapeInfo(), nullptr, + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce CountNonZero"); + ASSERT_TRUE(exp.equalsTo(z)); +#endif + +} + +TEST_F(NativeOpsTests, ReduceTest_7) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(13.); + + + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + x.syncToHost(); + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; +#endif + x.linspace(1.0); + x.syncToDevice(); + dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceFloat2(extra, + reduce::Mean, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce Mean"); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(NativeOpsTests, ReduceTest_8) { + auto x = 
NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create(120.); + auto exp = NDArrayFactory::create(325.); + + + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); +#endif + x.linspace(1.0); + x.syncToDevice(); + + dimension.syncToHost(); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + + ::execReduceSame2(extra, + reduce::Sum, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce Sum"); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(NativeOpsTests, ReduceTest_9) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(false); + auto z = NDArrayFactory::create(true); + + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); +#endif + x.linspace(1.0); + x.syncToDevice(); + + dimension.syncToHost(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduceBool2(extra, + reduce::All, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo()); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, Reduce3Test_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto 
exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); + + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduce3(extra, + reduce3::Dot, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo()); + //z.printIndexedBuffer("Z"); + //exp.printIndexedBuffer("Reduce3 Dot"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, Reduce3Test_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); + + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execReduce3Scalar(extra, + reduce3::Dot, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo()); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce3 Dot"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, Reduce3Test_3) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto 
y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); + + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); + dimension.syncToHost(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execReduce3Tad(extra, + reduce3::Dot, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), + nullptr, nullptr, nullptr, nullptr); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, Reduce3Test_4) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create(120.); + auto z = NDArrayFactory::create(650.); + + auto dimension = NDArrayFactory::create('c', {2}, {0, 1}); + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + x.linspace(1.0); + y.assign(2.); + x.syncToDevice(); + dimension.syncToHost(); + int* dimensions = reinterpret_cast(dimension.buffer()); + auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); + auto tadPackY = sd::ConstantTadHelper::getInstance().tadForDimensions(y.shapeInfo(), dimensions, 
dimension.lengthOf()); + + auto hTADShapeInfoX = tadPackX.primaryShapeInfo(); + auto hTADOffsetsX = tadPackX.primaryOffsets(); + auto hTADShapeInfoY = tadPackY.primaryShapeInfo(); + auto hTADOffsetsY = tadPackY.primaryOffsets(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execReduce3All(extra, + reduce3::Dot, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), + hTADShapeInfoX, hTADOffsetsX, hTADShapeInfoY, hTADOffsetsY); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, ScalarTest_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(10.); + auto exp = NDArrayFactory::create('c', {5,5}); + auto z = NDArrayFactory::create('c', {5,5}); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + x.linspace(1.0); + z.linspace(10., 10.); + //y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execScalar(extra, + scalar::Multiply, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, ScalarTest_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = 
NDArrayFactory::create(10.f); + auto exp = NDArrayFactory::create('c', {5,5}); + auto z = NDArrayFactory::create('c', {5,5}); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + x.linspace(1.0); + z.assign(false); + //y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execScalarBool(extra, + scalar::GreaterThan, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), nullptr); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15) != z.e(15)); +} + +TEST_F(NativeOpsTests, SummaryStatsScalarTest_1) { + auto x = NDArrayFactory::create('c', {5, 5}, {0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.5f, 0.7f, 0.9f, 0.8f, 0.1f, 0.11f, 0.12f, 0.5f, -0.8f, -0.9f, 0.4f, 0.1f, 0.2f, 0.3f, -0.3f, -0.5f, 0.2f, 0.3f, -0.3f, -0.5f}); + auto exp = NDArrayFactory::create(0.9f); + auto z = NDArrayFactory::create(0.21587136f); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execSummaryStatsScalar(extra, + variance::SummaryStatsVariance, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Standard Variance"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, 
SummaryStatsScalarTest_2) { + auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); + auto exp = NDArrayFactory::create(0.9); + auto z = NDArrayFactory::create(0.21587136); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + ::execSummaryStats(extra, + variance::SummaryStatsVariance, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), false); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Standard Variance"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, SummaryStatsScalarTest_3) { + auto x = NDArrayFactory::create('c', {5, 5}, {0.1, 0.2, 0.3, -0.3, -0.5, 0.5, 0.7, 0.9, 0.8, 0.1, 0.11, 0.12, 0.5, -0.8, -0.9, 0.4, 0.1, 0.2, 0.3, -0.3, -0.5, 0.2, 0.3, -0.3, -0.5}); + auto exp = NDArrayFactory::create(0.9); + auto z = NDArrayFactory::create(0.21587136); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + auto dimensions = NDArrayFactory::create({0, 1}); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimensions.dataBuffer()); + + ::execSummaryStatsTad(extra, + variance::SummaryStatsVariance, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + nullptr, + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &dimBuf, dimensions.shapeInfo(), dimensions.specialShapeInfo(), + false, + nullptr, nullptr); +// 
x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Standard Variance"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, TransformTest_1) { + auto x = NDArrayFactory::create('c', {5, 5}, {1, 4, 9, 16, 25, 36, 49, 64, 81, 100, 121, 144, 169, 196, 225, 256, 289, 324, 361, 400, 441, 484, 529, 576, 625}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5,5}); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + z.linspace(1.); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execTransformFloat(extra, + transform::Sqrt, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + nullptr); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Sqrt is"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, TransformTest_2) { + auto x = NDArrayFactory::create('c', {5, 5}, {1.f, 4.f, 9.f, 16.f, 25.f, 36.f, 49.f, 64.f, 81.f, 100.f, 121.f, 144.f, 169.f, 196.f, 225.f, 256.f, 289.f, 324.f, 361.f, 400.f, 441.f, 484.f, 529.f, 576.f, 625.f}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5,5}); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + z.linspace(1.); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execTransformSame(extra, + transform::Square, + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + nullptr); +// 
x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Square is"); + ASSERT_TRUE(exp.equalsTo(x)); +} + +TEST_F(NativeOpsTests, TransformTest_3) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5,5}); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + x.linspace(1.); + z.assign(true); + x.p(24, -25); + z.p(24, false); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execTransformBool(extra, + transform::IsPositive, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + nullptr); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("IsPositive"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, TransformTest_4) { + auto x = NDArrayFactory::create('c', {5, 5}, {0, 1, 2, 3, 2, 1, 0, 1.57, 1.57, 1.57, 3.141592, 3.141592, + 3.141592, 0, 0, 0, 0, 1, 1, 2, 2, 2, 1, 0, 0}); + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5,5}, {1., 0.540302, -0.416147, -0.989992, -0.416147, 0.540302, 1.0, + 0.000796, 0.000796, 0.000796, -1, -1, -1, 1., 1., 1.0, 1.0, + 0.540302, 0.540302, -0.416147, -0.416147, -0.416147, 0.540302, 1., 1.}); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + //z.linspace(1.); + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + + ::execTransformStrict(extra, + transform::Cosine, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + nullptr); +// 
x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Cosine"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, ScalarTadTest_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(10.f); + auto exp = NDArrayFactory::create('c', {5,5}); + auto z = NDArrayFactory::create('c', {5,5}); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + x.linspace(1.0); + z.linspace(10., 10.); + //y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + auto dimension = NDArrayFactory::create({0, 1}); + auto dimensions = reinterpret_cast(dimension.buffer()); + auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); + auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execScalarTad(extra, + scalar::Multiply, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), y.specialShapeInfo(), + nullptr, + &dimBuf, dimension.shapeInfo(), dimension.specialShapeInfo(), + tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("Reduce All"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, ScalarTadTest_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create(true); + auto exp = NDArrayFactory::create('c', {5,5}); + auto z = NDArrayFactory::create('c', {5, 5}); + + Nd4jPointer extra[6]; +#ifdef 
__CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + x.assign(false); + x.p(5, true); + x.p(15, true); + //z.linspace(10., 10.); + //y.assign(2.); + x.syncToDevice(); + z.syncToDevice(); + auto dimension = NDArrayFactory::create({0, 1}); + auto dimensions = reinterpret_cast(dimension.buffer()); + auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); + auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); + z.assign(true); + + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer expBuf(exp.dataBuffer()); + OpaqueDataBuffer dimBuf(dimension.dataBuffer()); + + ::execScalarBoolTad(extra, + scalar::And, + &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &expBuf, exp.shapeInfo(), + exp.specialShapeInfo(), + &yBuf, y.shapeInfo(), + y.specialShapeInfo(), + nullptr, + &dimBuf, dimension.shapeInfo(), + dimension.specialShapeInfo(), + tadPackX.primaryShapeInfo(), tadPackX.primaryOffsets(), tadPackZ.primaryShapeInfo(), tadPackZ.primaryOffsets()); +// x.printIndexedBuffer("Input"); +// exp.printIndexedBuffer("And"); + ASSERT_TRUE(exp.e(5) == z.e(5) && exp.e(15)); +} + +TEST_F(NativeOpsTests, ConcatTest_2) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {10,5}); + auto z = NDArrayFactory::create('c', {10,5}); + + Nd4jPointer extra[6]; +#ifdef __CUDABLAS__ + extra[1] = x.getContext()->getCudaStream(); + extra[0] = extra[2] = extra[3] = extra[4] = extra[5] = nullptr; + x.syncToHost(); + y.syncToHost(); + printf("Unsupported for CUDA platform yet.\n"); + return; +#endif + x.linspace(1.0); + y.linspace(26); + + //y.assign(2.); + x.syncToDevice(); 
+ z.syncToDevice(); + int d = 0; + auto dimension = NDArrayFactory::create('c', {1}, {d}); + auto dimensions = reinterpret_cast(dimension.buffer()); + //auto tadPackX = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dimensions, dimension.lengthOf()); + auto tadPackZ = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dimensions, dimension.lengthOf()); + exp.linspace(1); + Nd4jPointer datas[] = {x.buffer(), y.buffer()}; + Nd4jPointer shapes[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo()}; + + ::specialConcat(extra, 0, 2, datas, shapes, z.buffer(), z.shapeInfo(), nullptr, nullptr); + +// exp.printIndexedBuffer("Exp"); +// z.printIndexedBuffer("Concat"); + ASSERT_TRUE(exp.equalsTo(z)); +} + +TEST_F(NativeOpsTests, InitializeTest_1) { +// ::initializeDevicesAndFunctions(); +} + +TEST_F(NativeOpsTests, MallocTest_1) { + auto a = ::mallocHost(16, 0); + ::freeHost(a); + auto dA = ::mallocDevice(16, 0, 0); + ::freeDevice(dA, 0); +} + +TEST_F(NativeOpsTests, OMPTest_1) { + auto maxThreads = ::ompGetMaxThreads(); + auto numThreads = ::ompGetNumThreads(); + //::setOmpMinThreads(maxThreads); + //::setOmpNumThreads(numThreads); +} + +TEST_F(NativeOpsTests, CreateTest_1) { + auto xx = ::createContext(); + auto yy = ::createStream(); + auto zz = ::createEvent(); + ::destroyEvent(zz); + if (xx) + delete (LaunchContext*)xx; + if (yy) + printf("Stream should be destoyed before."); + +} + +TEST_F(NativeOpsTests, MemTest_1) { + auto x = NDArrayFactory::create({10, 20, 30, 40, 50}); + auto y = NDArrayFactory::create({20, 20, 20, 20, 20}); + +#ifdef __CUDABLAS__ + return ; +#endif + //ASSERT_TRUE(0 == ::memcpy(x.buffer(), y.buffer(), x.lengthOf() * sizeof(double), 0, nullptr)); + ASSERT_TRUE(0 == ::memcpyAsync(x.buffer(), y.buffer(), x.lengthOf() * sizeof(double), 0, nullptr)); + //ASSERT_TRUE(0 == ::memset(x.buffer(), 119, x.lengthOf() * sizeof(double), 0, nullptr)); + ASSERT_TRUE(0 == ::memsetAsync(x.buffer(), 119, x.lengthOf() * 
sizeof(double), 0, nullptr)); + +} + +TEST_F(NativeOpsTests, PullRowsTest_1) { + NDArray x('c', {5, 1}, {0,1,2,3,4}); + NDArray z('c', {4, 1}, sd::DataType::DOUBLE); + NDArray exp('c', {4, 1}, {0,2,3,4}); + + Nd4jLong indexes[] = {0,2,3,4}; + PointersManager pm(LaunchContext::defaultContext(), "NativeOpsTests::pullRows"); + auto pidx = reinterpret_cast(pm.replicatePointer(indexes, 4 * sizeof(Nd4jLong))); + + std::vector dims = {1}; + + auto xTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), dims); + auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(z.shapeInfo(), dims); + + Nd4jPointer nativeStart[2]; + +#ifdef __CUDABLAS__ + nativeStart[1] = (x.getContext()->getCudaStream()); +#endif + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + + pullRows(nativeStart, &xBuf, x.shapeInfo(), x.specialShapeInfo(), + &zBuf, z.shapeInfo(), z.specialShapeInfo(), + 4, pidx, + xTadPack.platformShapeInfo(), xTadPack.platformOffsets(), + zTadPack.platformShapeInfo(), zTadPack.platformOffsets()); + + ASSERT_TRUE(z.equalsTo(exp)); + pm.synchronize(); +} + +TEST_F(NativeOpsTests, TadPackTest_1) { + int dimension[] = {1}; + int const dimensionLength = 1; + auto x = NDArrayFactory::create('c', {2,3,4}); + sd::TadPack* pack = ::tadOnlyShapeInfo(x.shapeInfo(), + dimension, + dimensionLength); + ASSERT_TRUE(pack != nullptr); + delete pack; +} + +TEST_F(NativeOpsTests, AverageTest_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5,5}); + auto z = NDArrayFactory::create('c', {5,5}); +#ifdef __CUDABLAS__ + return; +#endif + x.linspace(1); + exp.linspace(1); + Nd4jPointer xList[] = {x.buffer(), x.buffer()}; + Nd4jPointer dxList[] = {x.specialBuffer(), x.specialBuffer()}; + ::average(nullptr, + xList, x.shapeInfo(), + dxList, x.specialShapeInfo(), + z.buffer(), z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), + 2, + 
x.lengthOf(), + true); +// z.printIndexedBuffer("RES"); + ASSERT_TRUE(z.equalsTo(exp)); +} + +TEST_F(NativeOpsTests, AccumulateTest_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5,5}); + auto z = NDArrayFactory::create('c', {5,5}); +#ifdef __CUDABLAS__ + return; +#endif + x.linspace(1); + exp.linspace(2,2); + Nd4jPointer xList[] = {x.buffer(), x.buffer()}; + Nd4jPointer dxList[] = {x.specialBuffer(), x.specialBuffer()}; + ::accumulate(nullptr, + xList, x.shapeInfo(), + dxList, x.specialShapeInfo(), + z.buffer(), z.shapeInfo(), + z.specialBuffer(), z.specialShapeInfo(), + 2, + x.lengthOf()); +// z.printIndexedBuffer("RES"); + ASSERT_TRUE(z.equalsTo(exp)); +} + +TEST_F(NativeOpsTests, P2PTest_1) { + ::enableP2P(true); + ::checkP2P(); + ::isP2PAvailable(); +} + +TEST_F(NativeOpsTests, ShuffleTest_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + auto y = NDArrayFactory::create('c', {5, 5}); + auto exp = NDArrayFactory::create('c', {5,5}); + auto z = NDArrayFactory::create('c', {5,5}); +#ifdef __CUDABLAS__ + return; +#endif + x.linspace(1); + y.linspace(34); + exp.linspace(2,2); + Nd4jPointer xList[] = {x.buffer(), x.buffer()}; + Nd4jPointer dxList[] = {x.specialBuffer(), y.specialBuffer()}; + Nd4jPointer xShapeList[] = {(Nd4jPointer)x.shapeInfo(), (Nd4jPointer)y.shapeInfo()}; + Nd4jPointer dxShapeList[] = {(Nd4jPointer)x.specialShapeInfo(), (Nd4jPointer)y.specialShapeInfo()}; + Nd4jPointer zList[] = {z.buffer(), z.buffer()}; + Nd4jPointer dzList[] = {z.specialBuffer(), z.specialBuffer()}; + Nd4jPointer zShapeList[] = {(Nd4jPointer)z.shapeInfo(), (Nd4jPointer)z.shapeInfo()}; + Nd4jPointer dzShapeList[] = {(Nd4jPointer)z.specialShapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + int shuffleMap[] = {1, 0, 4, 3, 2}; + auto zTadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), {1}); + Nd4jPointer zListOffset[] = 
{(Nd4jPointer)zTadPack.platformOffsets(), (Nd4jPointer)zTadPack.platformOffsets()}; + Nd4jPointer zListTADs[] = {(Nd4jPointer)zTadPack.platformShapeInfo(), (Nd4jPointer)zTadPack.platformShapeInfo()}; + ::shuffle(nullptr, + xList, xShapeList, + dxList, dxShapeList, + zList, zShapeList, + dzList, dzShapeList, + 2, + shuffleMap, zListTADs, zListOffset); +// z.printIndexedBuffer("RES"); +// x.printIndexedBuffer("INPUT shuffled"); +// y.printIndexedBuffer("INPUT 2 shuffled"); +// ASSERT_TRUE(z.equalsTo(exp)); +} + +TEST_F(NativeOpsTests, ConvertTypesTest_1) { + auto x = NDArrayFactory::create('c', {5, 5}); + + auto exp = NDArrayFactory::create('c', {5, 5}); + auto z = NDArrayFactory::create('c', {5, 5}); + +#ifdef __CUDABLAS__ + return; +#endif + x.linspace(2, 2); + exp.linspace(2, 2); + ::convertTypes(nullptr, ND4J_FLOAT32, x.buffer(), x.lengthOf(), ND4J_DOUBLE, z.buffer()); + ASSERT_TRUE(z.equalsTo(exp)); +} + +//TEST_F(NativeOpsTests, Test_Aggregations_1) { +// NativeOps ops; +// auto x = NDArrayFactory::create('c', {5,5}); +// auto y = NDArrayFactory::create('c', {5,5}); +// +// +// ops.execAggregate(nullptr, 0, maxArgs, maxShapes, maxIntArrays, maxIntArraySize, maxIndexArguments, maxRealArguments, pointer.data(), sd::DataType::FLOAT32); +// void **arguments, +// int numArguments, +// Nd4jLong **shapeArguments, +// int numShapeArguments, +// int *indexArguments, +// int numIndexArguments, +// int **intArrays, +// int numIntArrays, +// void *realArguments, +// int numRealArguments, +// sd::DataType dtype +//} + +TEST_F(NativeOpsTests, RandomTest_1) { + auto z = NDArrayFactory::create('c', {100}); + Nd4jPointer extra[] = {nullptr, nullptr}; +#ifdef __CUDABLAS__ + return; + extra[1] = z.getContext()->getCudaStream(); +#endif + graph::RandomGenerator rng(1023, 119); + double p = 0.5; + OpaqueDataBuffer zBuf(z.dataBuffer()); + + ::execRandom(extra, random::BernoulliDistribution, &rng, &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); +} + +TEST_F(NativeOpsTests, 
RandomTest_2) { + auto x = NDArrayFactory::create('c', {100}); + auto z = NDArrayFactory::create('c', {100}); + Nd4jPointer extra[] = {nullptr, nullptr}; +#ifdef __CUDABLAS__ + return; + extra[1] = z.getContext()->getCudaStream(); +#endif + x.linspace(0, 0.01); + graph::RandomGenerator rng(1023, 119); + double p = 0.5; + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + + ::execRandom2(extra, random::DropOut, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); +} + +TEST_F(NativeOpsTests, RandomTest_3) { + auto x = NDArrayFactory::create('c', {100}); + auto y = NDArrayFactory::create('c', {100}); + auto z = NDArrayFactory::create('c', {100}); + Nd4jPointer extra[] = {nullptr, nullptr}; +#ifdef __CUDABLAS__ + return; + extra[1] = z.getContext()->getCudaStream(); +#endif + x.linspace(0, 0.01); + x.linspace(1, -0.01); + graph::RandomGenerator rng(1023, 119); + double p = 0.5; + OpaqueDataBuffer xBuf(x.dataBuffer()); + OpaqueDataBuffer yBuf(y.dataBuffer()); + OpaqueDataBuffer zBuf(z.dataBuffer()); + + ::execRandom3(extra, random::ProbablisticMerge, &rng, &xBuf, x.shapeInfo(), x.specialShapeInfo(), &yBuf, + y.shapeInfo(), y.specialShapeInfo(), &zBuf, z.shapeInfo(), z.specialShapeInfo(), &p); +} + +TEST_F(NativeOpsTests, RandomTest_4) { +#ifdef __CUDABLAS__ + return ; +#endif + graph::RandomGenerator* rng = (graph::RandomGenerator*)::initRandom(nullptr, 1023, 0, nullptr); + ::refreshBuffer(nullptr, 1203L, rng); + ::reSeedBuffer(nullptr, 3113L, rng); + ::destroyRandom(rng); +} + +TEST_F(NativeOpsTests, SortTest_1) { +#ifdef __CUDABLAS__ + return ; +#endif + auto sortedVals = NDArrayFactory::create( + {10, 1, 5, 120, 34, 5, 78, 138, 3, 111, 331, 29, 91, 71, 73, 50, 56, 4}); + auto exp = NDArrayFactory::create({1, 3, 4, 5, 5, 10, 29, 34, 50, 56, 71, 73, 78, 91, 111, 120, 138, 331}); + + ::sort(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), + 
sortedVals.specialShapeInfo(), false); + ASSERT_TRUE(sortedVals.equalsTo(exp)); +} + +TEST_F(NativeOpsTests, SortTests_2) { + auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + Nd4jPointer extras[2]; +#ifdef __CUDABLAS__ + extras[1] = LaunchContext::defaultContext()->getCudaStream(); +#endif +// OpaqueDataBuffer xBuf(x.dataBuffer()); +// OpaqueDataBuffer yBuf(y.dataBuffer()); +// OpaqueDataBuffer expBuf(exp.dataBuffer()); +// OpaqueDataBuffer dimBuf(exp.dataBuffer()); + + ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +TEST_F(NativeOpsTests, SortTest_3) { + auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + +#ifdef __CUDABLAS__ + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; +#else + Nd4jPointer extras[2]; +#endif + + ::sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +TEST_F(NativeOpsTests, SortTest_4) { +#ifdef __CUDABLAS__ + return ; +#endif + auto sortedVals = NDArrayFactory::create('c', {3, 6}, + { 10, 1, 5, 120, 
34, 5, + 78, 138, 3, 111, 331, 29, + 91, 71, 73, 50, 56, 4}); + auto exp = NDArrayFactory::create('c', {3, 6}, {1, 5, 5, 10, 34, 120, 3, 29, 78, 111, 138, 331, 4, 50, 56, 71, 73, 91}); + + std::vector dims({1}); + auto packX = ConstantTadHelper::getInstance().tadForDimensions(sortedVals.shapeInfo(), {1}); + ::sortTad(nullptr, sortedVals.buffer(), sortedVals.shapeInfo(), sortedVals.specialBuffer(), + sortedVals.specialShapeInfo(), dims.data(), dims.size(), packX.platformShapeInfo(), packX.platformOffsets(), false); +// sortedVals.printBuffer("OUT"); +// exp.printIndexedBuffer("EXP"); + ASSERT_TRUE(sortedVals.equalsTo(exp)); +} + +TEST_F(NativeOpsTests, SortTests_5) { + auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2]; +#ifdef __CUDABLAS__ + extras[1] = LaunchContext::defaultContext()->getCudaStream(); +#endif + + int axis = 1; + + ::sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); + +// k.printIndexedBuffer("k"); +// v.printIndexedBuffer("v"); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +TEST_F(NativeOpsTests, SortTests_6) { + auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = 
NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2]; +#ifdef __CUDABLAS__ + extras[1] = LaunchContext::defaultContext()->getCudaStream(); +#endif + + int axis = 1; + + ::sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +//TEST_F(NativeOpsTests, MapTests_1) { +//#ifdef __CUDABLAS__ +// return ; +//#endif +//#ifdef GTEST_OS_LINUX +// auto ptrMap = ::mmapFile(nullptr, "/tmp/maptest.$$$", 100LL); +// +// ::munmapFile(nullptr, ptrMap, 100LL); +//#endif +// +//} + +TEST_F(NativeOpsTests, MapTests_1) { + //printf("Custom ops: %s\n", ::getAllCustomOps()); + //printf("All ops: %s\n", ::getAllOperations()); + + ::getAllCustomOps(); + ::getAllOperations(); +} + +TEST_F(NativeOpsTests, CustomOpTest_1) { + auto x = NDArrayFactory::create('c', {1, 6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {6}); + auto e = NDArrayFactory::create('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + + sd::ops::squeeze op; + + Nd4jPointer ptrsInBuffer[] = {(Nd4jPointer) x.buffer(), x.specialBuffer()}; + Nd4jPointer ptrsInShapes[] = {(Nd4jPointer) x.shapeInfo(), (Nd4jPointer)x.specialShapeInfo()}; + + Nd4jPointer ptrsOutBuffers[] = {(Nd4jPointer) z.buffer(), z.specialBuffer()}; + Nd4jPointer ptrsOutShapes[] = {(Nd4jPointer) z.shapeInfo(), (Nd4jPointer)z.specialShapeInfo()}; + + + auto status = ::execCustomOp(nullptr, op.getOpHash(), ptrsInBuffer, ptrsInShapes, 1, ptrsOutBuffers, ptrsOutShapes, 1, nullptr, 0, nullptr, 0, nullptr, 0, false); + ASSERT_EQ(Status::OK(), status); + + ASSERT_EQ(e, z); +} +TEST_F(NativeOpsTests, 
CustomOpTests_2) { + auto array0 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto array1 = NDArrayFactory::create('c', {3, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f}); + auto z = NDArrayFactory::create('c', {3, 2}); + + auto exp = NDArrayFactory::create('c', {3, 2}, {2.f, 4.f, 6.f, 8.f, 10.f, 12.f}); + Context ctx(1); + + NDArray::prepareSpecialUse({&z}, {&array0, &array1}); + + ctx.setInputArray(0, array0.buffer(), array0.shapeInfo(), array0.specialBuffer(), array0.specialShapeInfo()); + ctx.setInputArray(1, array1.buffer(), array1.shapeInfo(), array1.specialBuffer(), array1.specialShapeInfo()); + ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo()); + + ASSERT_EQ(2, ctx.width()); + + sd::ops::add op; + ::execCustomOp2(nullptr, op.getOpHash(), &ctx); + + NDArray::registerSpecialUse({&z}, {&array0, &array1}); + + ASSERT_EQ(exp, z); +} +TEST_F(NativeOpsTests, CalculateOutputShapeTests_1) { + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); + + sd::ops::conv2d op; + + std::vector tArgs({}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + + Nd4jPointer ptrs[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weights.shapeInfo()}; +#ifdef __CUDABLAS__ + return; +#endif + + auto shapeList = ::calculateOutputShapes(nullptr, op.getOpHash(), ptrs, 2, tArgs.data(), tArgs.size(), iArgs.data(), iArgs.size()); + + ASSERT_EQ(1, shapeList->size()); + + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); + ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); + ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); + ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); + + //int *ptr = (int *) shapeList[0]; + //delete[] ptr; + //delete 
shapeList; + + ::deleteShapeList((Nd4jPointer) shapeList); +} + +TEST_F(NativeOpsTests, CalculateOutputShapeTests_2) { + auto input = NDArrayFactory::create('c', {1, 2, 5, 4}); + auto weights = NDArrayFactory::create('c', {2, 2, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3, 5, 4}); + + sd::ops::conv2d op; + + std::vector tArgs({}); + std::vector bArgsF({}); + std::vector iArgs({2, 2, 1, 1, 0, 0, 1, 1, 1}); + + Nd4jPointer shapePtrs[] = {(Nd4jPointer) input.shapeInfo(), (Nd4jPointer) weights.shapeInfo()}; + Nd4jPointer dataPtrs[] = {(Nd4jPointer)input.buffer(), (Nd4jPointer)weights.buffer()}; +#ifdef __CUDABLAS__ + return; +#endif + + auto shapeList = ::calculateOutputShapes2(nullptr, op.getOpHash(), dataPtrs, shapePtrs, 2, const_cast(tArgs.data()), tArgs.size(), + const_cast(iArgs.data()), iArgs.size(), nullptr, bArgsF.size(), nullptr, 0); +// Nd4jPointer* extraPointers, Nd4jLong hash, Nd4jPointer* inputBuffers, Nd4jPointer* inputShapes, int numInputShapes, double* tArgs, int numTArgs, Nd4jLong *iArgs, int numIArgs, bool *bArgs, int numBArgs + ASSERT_EQ(1, shapeList->size()); + + ASSERT_EQ(exp.rankOf(), shape::rank((Nd4jLong *)shapeList->at(0))); + ASSERT_EQ(exp.sizeAt(0), shape::shapeOf((Nd4jLong *)shapeList->at(0))[0]); + ASSERT_EQ(exp.sizeAt(1), shape::shapeOf((Nd4jLong *)shapeList->at(0))[1]); + ASSERT_EQ(exp.sizeAt(2), shape::shapeOf((Nd4jLong *)shapeList->at(0))[2]); + ASSERT_EQ(exp.sizeAt(3), shape::shapeOf((Nd4jLong *)shapeList->at(0))[3]); + + //int *ptr = (int *) shapeList[0]; + //delete[] ptr; + //delete shapeList; + + ::deleteShapeList((Nd4jPointer) shapeList); +} + + +TEST_F(NativeOpsTests, interop_databuffer_tests_1) { + auto idb = ::allocateDataBuffer(100, 10, false); + auto ptr = ::dbPrimaryBuffer(idb); + ::deleteDataBuffer(idb); +} + +//Uncomment when needed only - massive calculations +//TEST_F(NativeOpsTests, BenchmarkTests_1) { +// +// printf("%s\n", ::runLightBenchmarkSuit(true)); +// printf("%s\n", ::runFullBenchmarkSuit(true)); +//} 
diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NlpTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NlpTests.cpp new file mode 100644 index 000000000..14a20b99d --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NlpTests.cpp @@ -0,0 +1,476 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + + +using namespace sd; + + +class NlpTests : public testing::Test { +public: + + NlpTests() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(NlpTests, basic_sg_hs_test_1) { + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01001f); + exp1.assign(0.020005f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto indices = NDArrayFactory::create('c', {1}, {1}); + auto codes = NDArrayFactory::create('c', {1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(1L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({0,1, 0,0}, true); + auto row1 = syn1({1,2, 0,0}, true); + + ASSERT_EQ(exp0, row0); + ASSERT_EQ(exp1, row1); + + +} + +TEST_F(NlpTests, basic_sg_hs_test_2) { + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = 
NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto indices = NDArrayFactory::create('c', {2}, {1, 2}); + auto codes = NDArrayFactory::create('c', {2}, {0, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(1L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({0,1, 0,0}, true); + auto row1 = syn1({1,2, 0,0}, true); + auto row2 = syn1({2,3, 0,0}, true); + + ASSERT_EQ(exp0, row0); + ASSERT_EQ(exp1, row1); + ASSERT_EQ(exp2, row2); + + +} + +TEST_F(NlpTests, basic_sg_hs_test_3) { + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto indices0 = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto indices1 = NDArrayFactory::create('c', {3}, {3, 1, 2}); + auto codes00 = NDArrayFactory::create('c', {3}, {0, 1, 1}); + auto codes01 = NDArrayFactory::create('c', {3}, {1, 0, 1}); + auto syn00 = NDArrayFactory::create('c', {100, 10}); + auto syn01 = NDArrayFactory::create('c', {100, 10}); + auto syn10 = NDArrayFactory::create('c', {100, 10}); + auto syn11 = NDArrayFactory::create('c', {100, 10}); + 
auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {10}); + + RandomGenerator rng(119L, 198L); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn00, 0.0, 1.0); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &syn10, 0.0, 1.0); + + syn01.assign(syn00); + syn11.assign(syn10); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(1L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result0 = op.evaluate({&target, &ngStarter, &indices0, &codes00, &syn00, &syn10, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); + auto result1 = op.evaluate({&target, &ngStarter, &indices1, &codes01, &syn01, &syn11, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result0.status()); + + auto row00 = syn00({0,1, 0,0}, true); + auto row01 = syn01({0,1, 0,0}, true); + auto row1 = syn10({1,2, 0,0}, true); + auto row2 = syn11({1,2, 0,0}, true); + + ASSERT_EQ(row2, row1); + ASSERT_EQ(row00, row01); +} + +TEST_F(NlpTests, basic_sg_hs_ns_test_1) { + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::create(1); + auto indices = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + auto codes = NDArrayFactory::create('c', {5}, {1, 1, 0, 1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 150}); + auto syn1 = NDArrayFactory::create('c', {100, 150}); + auto syn1Neg = NDArrayFactory::create('c', {100, 150}); + auto expTable = NDArrayFactory::create('c', {1000}); + auto negTable = NDArrayFactory::create('c', {1000}); + auto neu1e = NDArrayFactory::create('c', {10}); + negTable.linspace(1.0); + + auto alpha = NDArrayFactory::create(1.25); + auto randomValue = 
NDArrayFactory::create(119L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {3}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + +} + +TEST_F(NlpTests, basic_sg_ns_test_1) { + auto exp0 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01); + + auto target = NDArrayFactory::create(1); + auto ngStarter = NDArrayFactory::create(3); + auto indices = NDArrayFactory::empty(); + auto codes = NDArrayFactory::empty(); + auto syn0 = NDArrayFactory::create('c', {10, 10}); + auto syn1 = NDArrayFactory::empty(); + auto syn1Neg = NDArrayFactory::create('c', {10, 10}); + auto expTable = NDArrayFactory::create('c', {1000}); + auto negTable = NDArrayFactory::create('c', {1000}); + auto neu1e = NDArrayFactory::create('c', {10}); + + auto syn1Neg2 = NDArrayFactory::create('c', {10, 10}); + + syn0.assign(0.01); + syn1.assign(0.02); + syn1Neg.assign(0.03); + syn1Neg2.assign(0.03); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.001); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::skipgram op; + auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {1, 1}, {false}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({1,2, 0,0}, true); + + ASSERT_EQ(exp0, row0); + ASSERT_FALSE(syn1Neg2.equalsTo(syn1Neg, 1e-6)); + + +} + +TEST_F(NlpTests, basic_cb_hs_test_1) { + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0095f); + exp1.assign(0.019875f); + exp2.assign(0.02f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = 
NDArrayFactory::empty(); + auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto locked = NDArrayFactory::create('c', {3}); + auto indices = NDArrayFactory::create('c', {2}, {4, 5}); + auto codes = NDArrayFactory::create('c', {2}, {1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto numWords = NDArrayFactory::create('c', {1}, {1}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.025); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row_s0_0 = syn0({0,1, 0,0}, true); + auto row_s0_1 = syn0({1,2, 0,0}, true); + auto row_s0_2 = syn0({2,3, 0,0}, true); + + auto row_s1_4 = syn1({4,5, 0,0}, true); + auto row_s1_5 = syn1({5,6, 0,0}, true); + auto row_s1_6 = syn1({6,7, 0,0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + + ASSERT_EQ(exp1, row_s1_4); + ASSERT_EQ(exp1, row_s1_5); + ASSERT_EQ(exp2, row_s1_6); + + +} + +TEST_F(NlpTests, basic_cb_ns_test_1) { + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0096265625); + exp1.assign(0.01); + exp2.assign(0.030125f); + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::create(6); + auto context = NDArrayFactory::create('c', {3}, {0, 1, 2}); + auto locked = NDArrayFactory::create('c', {3}); + auto indices = 
NDArrayFactory::empty(); + auto codes = NDArrayFactory::empty(); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::create('c', {100, 10}); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::create('c', {100000}); + auto numWords = NDArrayFactory::create('c', {2}, {1, 2}); + + syn0.assign(0.01); + syn1.assign(0.02); + syn1Neg.assign(0.03); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create(0.025); + auto randomValue = NDArrayFactory::create(2L); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {1, 2, 0}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row_s0_0 = syn0({0,1, 0,0}, true); + auto row_s0_1 = syn0({1,2, 0,0}, true); + auto row_s0_2 = syn0({2,3, 0,0}, true); + + auto row_s1_4 = syn1({4,5, 0,0}, true); + auto row_s1_5 = syn1({5,6, 0,0}, true); + auto row_s1_6 = syn1Neg({6,7, 0,0}, true); + + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + ASSERT_EQ(exp2, row_s1_6); + + +} + +TEST_F(NlpTests, test_sg_hs_batch_1) { + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create('c', {2}, {0, 5}); + auto ngStarter = NDArrayFactory::empty(); + auto indices = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + auto codes = NDArrayFactory::create('c', {2, 2}, {0, 1, 1, 1}); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = 
NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + + auto alpha = NDArrayFactory::create('c', {2}, {0.001, 0.024}); + auto randomValue = NDArrayFactory::create('c', {2}, {1L, 3L}); + auto inferenceVector = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {2, 10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + sd::ops::skipgram op; + auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {}, {false, true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto row0 = syn0({0,1, 0,0}, true); + auto row1 = syn1({1,2, 0,0}, true); + auto row2 = syn1({2,3, 0,0}, true); + + ASSERT_TRUE(exp0.equalsTo(row0, 1e-6)); + ASSERT_TRUE(exp1.equalsTo(row1, 1e-6)); + ASSERT_TRUE(exp2.equalsTo(row2, 1e-6)); + + +} + +TEST_F(NlpTests, test_sg_ns_batch_1) { + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.01f); + exp1.assign(0.020005f); + exp2.assign(0.019995f); + + auto target = NDArrayFactory::create('c', {2}, {0, 5}); + auto ngStarter = NDArrayFactory::create('c', {2}, {3, 8}); + auto indices = NDArrayFactory::empty(); + auto codes = NDArrayFactory::empty(); + auto syn0 = NDArrayFactory::create('c', {100, 10}); + auto syn1Neg = NDArrayFactory::create('c', {100, 10}); + auto syn1 = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::create('c', {100000}); + + auto alpha = NDArrayFactory::create('c', {2}, {0.001, 0.024}); + auto randomValue = NDArrayFactory::create('c', {2}, {1L, 3L}); + auto inferenceVector = NDArrayFactory::empty(); + auto neu1e = NDArrayFactory::create('c', {2, 10}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + negTable.linspace(0.0); + + sd::ops::skipgram op; + 
auto result = op.evaluate({&target, &ngStarter, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &inferenceVector, &neu1e}, {}, {4, 5}, {false, true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + +} + +TEST_F(NlpTests, test_cbow_hs_batch_1) { +#ifdef __CUDABLAS__ + return ; +#endif + + auto target = NDArrayFactory::create(0); + auto ngStarter = NDArrayFactory::empty(); + auto context = NDArrayFactory::create('c', {2, 3}, {0, 1, 2, 100, 101, 102}); + auto locked = NDArrayFactory::create('c', {2, 3}); + auto indices = NDArrayFactory::create('c', {2, 2}, {4, 5, 40, 50}); + auto codes = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + auto syn0 = NDArrayFactory::create('c', {244, 10}); + auto syn1 = NDArrayFactory::create('c', {244, 10}); + auto syn1Neg = NDArrayFactory::empty(); + auto expTable = NDArrayFactory::create('c', {10000}); + auto negTable = NDArrayFactory::empty(); + auto numWords = NDArrayFactory::create('c', {2}, {1, 2}); + + syn0.assign(0.01); + syn1.assign(0.02); + expTable.assign(0.5); + + auto alpha = NDArrayFactory::create('c', {2}, {0.025, 0.025}); + auto randomValue = NDArrayFactory::create('c', {2}, {2L, 2L}); + auto inferenceVector = NDArrayFactory::empty(); + + sd::ops::cbow op; + auto result = op.evaluate({&target, &ngStarter, &context, &indices, &codes, &syn0, &syn1, &syn1Neg, &expTable, &negTable, &alpha, &randomValue, &numWords, &locked, &inferenceVector}, {}, {}, {true}, {}, true); + ASSERT_EQ(Status::OK(), result.status()); + + auto exp0 = NDArrayFactory::create('c', {1, 10}); + auto exp1 = NDArrayFactory::create('c', {1, 10}); + auto exp2 = NDArrayFactory::create('c', {1, 10}); + + exp0.assign(0.0095f); + exp1.assign(0.019875f); + exp2.assign(0.02f); + + auto row_s0_0 = syn0({0,1, 0,0}, true); + auto row_s0_1 = syn0({1,2, 0,0}, true); + auto row_s0_2 = syn0({2,3, 0,0}, true); + + auto row_s1_4 = syn1({4,5, 0,0}, true); + auto row_s1_5 = syn1({5,6, 0,0}, true); + auto row_s1_6 = 
syn1({6,7, 0,0}, true); + + ASSERT_EQ(exp0, row_s0_0); + ASSERT_EQ(exp0, row_s0_1); + ASSERT_EQ(exp0, row_s0_2); + ASSERT_EQ(exp1, row_s1_4); + ASSERT_EQ(exp1, row_s1_5); + ASSERT_EQ(exp2, row_s1_6); + +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NodeTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NodeTests.cpp new file mode 100644 index 000000000..f67781e4f --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/NodeTests.cpp @@ -0,0 +1,75 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 21.02.18. 
+// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class NodeTests : public testing::Test { +public: + +}; + +TEST_F(NodeTests, Test_Dtype_Conversion_1) { + auto nodeA = new Node(OpType_TRANSFORM_SAME, 0, 1, {-1}, {2}); + + auto nd = nodeA->asT(); + auto nf = nd->asT(); + + ASSERT_EQ(nodeA->id(), nf->id()); + ASSERT_EQ(*nodeA->name(), *nf->name()); + ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); + ASSERT_EQ(nodeA->opType(), nf->opType()); + ASSERT_EQ(nodeA->opNum(), nf->opNum()); + + delete nodeA; + delete nd; + delete nf; +} + + +TEST_F(NodeTests, Test_Dtype_Conversion_2) { + sd::ops::add opA; + + //auto nodeA = new Node(OpType_CUSTOM, 0, 1, {-1}, {2}); + auto nodeA = new Node(&opA, 1, {-1}, {2}); + //nodeA->setCustomOp(&op); + + auto nd = nodeA->asT(); + auto nf = nd->asT(); + + ASSERT_EQ(nodeA->id(), nf->id()); + ASSERT_EQ(*nodeA->name(), *nf->name()); +// ASSERT_EQ(nodeA->getOpClass(), nf->getOpClass()); + ASSERT_EQ(nodeA->opType(), nf->opType()); + ASSERT_EQ(nodeA->opNum(), nf->opNum()); + ASSERT_EQ(nodeA->getCustomOp()->getOpHash(), nf->getCustomOp()->getOpHash()); + + delete nodeA; + delete nd; + delete nf; +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp new file mode 100644 index 000000000..9d9c87b61 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OmpLaunchHelperTests.cpp @@ -0,0 +1,125 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 30.06.18. +// + +#include "testlayers.h" +#include +#include + + +using namespace sd; +using namespace sd::graph; + +class OmpLaunchHelperTests : public testing::Test { +private: + int ewt = 0; +public: + OmpLaunchHelperTests() { + this->ewt = Environment::getInstance().elementwiseThreshold(); + Environment::getInstance().setElementwiseThreshold(1000); + }; + + ~OmpLaunchHelperTests() { + Environment::getInstance().setElementwiseThreshold(this->ewt); + } +}; + +TEST_F(OmpLaunchHelperTests, Test_BetterSpan_1) { + auto span = OmpLaunchHelper::betterSpan(1000, 4); + ASSERT_EQ(250, span); +} + +TEST_F(OmpLaunchHelperTests, Test_BetterSpan_2) { + auto span = OmpLaunchHelper::betterSpan(1001, 4); + ASSERT_EQ(251, span); +} + +TEST_F(OmpLaunchHelperTests, Test_BetterSpan_3) { + auto span = OmpLaunchHelper::betterSpan(1002, 4); + ASSERT_EQ(251, span); +} + +TEST_F(OmpLaunchHelperTests, Test_BetterSpan_5) { + auto span = OmpLaunchHelper::betterSpan(1003, 4); + ASSERT_EQ(251, span); +} + +TEST_F(OmpLaunchHelperTests, Test_BetterSpan_6) { + auto span = OmpLaunchHelper::betterSpan(1004, 4); + ASSERT_EQ(251, span); +} + + +TEST_F(OmpLaunchHelperTests, Test_BetterThreads_1) { + auto n = OmpLaunchHelper::betterThreads(4000, 6); + ASSERT_EQ(4, n); +} + +TEST_F(OmpLaunchHelperTests, Test_BetterThreads_2) { + auto n = OmpLaunchHelper::betterThreads(12000, 6); + ASSERT_EQ(6, n); +} + 
+TEST_F(OmpLaunchHelperTests, Test_BetterThreads_3) { + auto n = OmpLaunchHelper::betterThreads(899, 6); + ASSERT_EQ(1, n); +} + +TEST_F(OmpLaunchHelperTests, test_tad_threads_1) { + Nd4jLong numTads = 16; + Nd4jLong tadLength = 16; + +// nd4j_printf("TT: [%i]; ET: [%i];\n", Environment::getInstance().tadThreshold(), Environment::getInstance().elementwiseThreshold()); + ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); +} + +TEST_F(OmpLaunchHelperTests, test_tad_threads_2) { + if (omp_get_max_threads() <= 1) + return; + + Nd4jLong numTads = 2; + Nd4jLong tadLength = Environment::getInstance().elementwiseThreshold(); + + ASSERT_EQ(2, OmpLaunchHelper::tadThreads(tadLength, numTads)); +} + +TEST_F(OmpLaunchHelperTests, test_tad_threads_3) { + Nd4jLong numTads = 2; + Nd4jLong tadLength = 128; + + ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); +} + +TEST_F(OmpLaunchHelperTests, test_tad_threads_4) { + Nd4jLong numTads = 4; + Nd4jLong tadLength = 64; + + ASSERT_EQ(1, OmpLaunchHelper::tadThreads(tadLength, numTads)); +} + +TEST_F(OmpLaunchHelperTests, test_tad_threads_5) { + auto exp = omp_get_max_threads(); + + Nd4jLong numTads = exp; + Nd4jLong tadLength = Environment::getInstance().elementwiseThreshold(); + + ASSERT_EQ(exp, OmpLaunchHelper::tadThreads(tadLength, numTads)); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OneOffTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OneOffTests.cpp new file mode 100644 index 000000000..53a7ba101 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OneOffTests.cpp @@ -0,0 +1,390 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 11.10.2017. +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::ops; + +class OneOffTests : public testing::Test { +public: + +}; + +TEST_F(OneOffTests, test_avg_pool_3d_1) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/avg_pooling3d.fb"); + + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + delete graph; +} + +TEST_F(OneOffTests, test_non2d_0A_1) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_0A.fb"); + + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + delete graph; +} + +/* +TEST_F(OneOffTests, test_assert_scalar_float32_1) { + sd::ops::Assert op; + sd::ops::identity op1; + sd::ops::noop op2; + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/scalar_float32.fb"); + + ASSERT_TRUE(graph != nullptr); + + graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + delete graph; +}*/ + +TEST_F(OneOffTests, test_assert_scalar_float32_2) { + sd::ops::Assert op; + sd::ops::identity op1; + sd::ops::noop op2; + auto graph = 
GraphExecutioner::importFromFlatBuffers("./resources/assertsomething.fb"); + + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + delete graph; +} + + +TEST_F(OneOffTests, test_pad_1D_1) { + auto e = NDArrayFactory::create('c', {7}, {10.f,0.778786f, 0.801198f, 0.724375f, 0.230894f, 0.727141f,10.f}); + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/pad_1D.fb"); + + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(4)); + + auto z = graph->getVariableSpace()->getVariable(4)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + // z->printIndexedBuffer("z"); + + ASSERT_EQ(e, *z); + delete graph; +} +/* +TEST_F(OneOffTests, test_scatter_nd_update_1) { + + auto e = NDArrayFactory::create('c', {10, 7}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.20446908f, 0.37918627f, 0.99792874f, 0.71881700f, 0.18677747f, + 0.78299069f, 0.55216062f, 0.40746713f, 0.92128086f, 0.57195139f, 0.44686234f, 0.30861020f, 0.31026053f, 0.09293187f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 0.95073712f, 0.45613325f, 0.95149803f, 0.88341522f, 0.54366302f, 0.50060666f, 0.39031255f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/scatter_nd_update.fb"); + ASSERT_TRUE(graph != nullptr); + + graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); + + auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + z->printIndexedBuffer("z"); + + ASSERT_EQ(e, *z); + + delete graph; +} + */ + 
+TEST_F(OneOffTests, test_conv2d_nhwc_failed_1) { + auto e = NDArrayFactory::create('c', {1, 5, 5, 6}, {0.55744928f, 0.76827729f, 1.09401524f, 0.00000000f, 0.00000000f, 0.00000000f, 0.56373537f, 0.90029907f, 0.78997850f, 0.00000000f, 0.00000000f, 0.00000000f, 0.14252824f, 0.95961076f, 0.87750554f, 0.00000000f, 0.00000000f, 0.00000000f, 0.44874173f, 0.99537718f, 1.17154264f, 0.00000000f, 0.00000000f, 0.00000000f, 0.60377145f, 0.79939061f, 0.56031001f, 0.00000000f, 0.00000000f, 0.00000000f, 0.52975273f, 0.90678585f, 0.73763013f, 0.00000000f, 0.00000000f, 0.00000000f, 0.22146404f, 0.82499605f, 0.47222072f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42772964f, 0.39793295f, 0.71436501f, 0.00000000f, 0.00000000f, 0.00000000f, 0.48836520f, 1.01658893f, 0.74419701f, 0.00000000f, 0.00000000f, 0.00000000f, 0.78984612f, 0.94083673f, 0.83841157f, 0.00000000f, 0.00000000f, 0.00000000f, 0.40448499f, 0.67732805f, 0.75499672f, 0.00000000f, 0.00000000f, 0.00000000f, 0.43675962f, 0.79476535f, 0.72976631f, 0.00000000f, 0.00000000f, 0.00000000f, 0.58808053f, 0.65222591f, 0.72552216f, 0.00000000f, 0.00000000f, 0.00000000f, 0.37445742f, 1.22581339f, 1.05341125f, 0.00000000f, 0.00000000f, 0.00000000f, 0.30095795f, 0.59941679f, 0.63323414f, 0.00000000f, 0.00000000f, 0.00000000f, 0.24199286f, 1.02546394f, 0.69537812f, 0.00000000f, 0.00000000f, 0.00000000f, 0.23628944f, 0.90791851f, 1.01209974f, 0.00000000f, 0.00000000f, 0.00000000f, 0.62740159f, 0.56518674f, 0.76692569f, 0.00000000f, 0.00000000f, 0.00000000f, 0.13327584f, 0.32628393f, 0.10280430f, 0.00000000f, 0.00000000f, 0.00000000f, 0.42691272f, 0.25625113f, 0.30524066f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17797673f, 0.84179950f, 0.80061519f, 0.00000000f, 0.00000000f, 0.00000000f, 0.00199084f, 0.51838887f, 0.43932241f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16684581f, 0.50822425f, 0.48668745f, 0.00000000f, 0.00000000f, 0.00000000f, 0.16749343f, 0.93093169f, 0.86871749f, 0.00000000f, 0.00000000f, 0.00000000f, 0.17486368f, 
0.44460732f, 0.44499981f, 0.00000000f, 0.00000000f, 0.00000000f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb"); + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(9)); + + auto z = graph->getVariableSpace()->getVariable(9)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + ASSERT_EQ(e, *z); + + delete graph; +} + +TEST_F(OneOffTests, test_tensor_array_1) { + auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb"); + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(5)); + + auto z = graph->getVariableSpace()->getVariable(5)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + ASSERT_EQ(e, *z); + + delete graph; +} + +TEST_F(OneOffTests, test_tensor_array_2) { + auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb"); + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); + + auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + ASSERT_EQ(e, *z); + + delete graph; +} + +TEST_F(OneOffTests, test_tensor_array_3) { + auto e = NDArrayFactory::create('c', {3, 2, 3}, {7, 2, 9, 4, 3, 3, 8, 7, 0, 0, 6, 8, 7, 9, 0, 1, 1, 
4}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb"); + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(15)); + + auto z = graph->getVariableSpace()->getVariable(15)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + ASSERT_EQ(e, *z); + + delete graph; +} + +TEST_F(OneOffTests, test_tensor_array_4) { + auto e = NDArrayFactory::create('c', {2, 3}, {4, 3, 1, 1, 1, 0}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/tensor_array_unstack_sz1_int64_nodynamic_noname_shape2-3.fb"); + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(11)); + + auto z = graph->getVariableSpace()->getVariable(11)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + ASSERT_EQ(e, *z); + + delete graph; +} + +TEST_F(OneOffTests, test_assert_4) { + auto e = NDArrayFactory::create('c', {2, 2}, {1, 1, 1, 1}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/assert_type_rank2_int64.fb"); + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); + + auto z = graph->getVariableSpace()->getVariable(1)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + ASSERT_EQ(e, *z); + + delete graph; +} + +// TEST_F(OneOffTests, test_cond_true_1) { +// auto e = NDArrayFactory::create('c', {5}, {1.f, 2.f, 3.f, 4.f, 5.f}); + +// auto graph = GraphExecutioner::importFromFlatBuffers("./resources/cond_true.fb"); +// ASSERT_TRUE(graph != nullptr); + +// graph->printOut(); + + +// Nd4jStatus status = GraphExecutioner::execute(graph); +// 
ASSERT_EQ(Status::OK(), status); +// ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); + +// auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); +// ASSERT_TRUE(z != nullptr); + +// z->printIndexedBuffer("z buffer"); + +// ASSERT_EQ(e, *z); + +// delete graph; +// } + +/* +TEST_F(OneOffTests, test_cond_false_1) { + auto e = NDArrayFactory::create('c', {5}, {1.f, 1.f, 1.f, 1.f, 1.f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/cond_false.fb"); + ASSERT_TRUE(graph != nullptr); + + graph->printOut(); + + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(6)); + + auto z = graph->getVariableSpace()->getVariable(6)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + z->printIndexedBuffer("z buffer"); + + ASSERT_EQ(e, *z); + + delete graph; +} +*/ + +TEST_F(OneOffTests, test_identity_n_2) { + auto e = NDArrayFactory::create('c', {2, 3}, {0.77878559f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f}); + + sd::ops::identity_n op; + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/identity_n_2.fb"); + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1, 1)); + + auto z = graph->getVariableSpace()->getVariable(1)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + ASSERT_EQ(e, *z); + + delete graph; +} + +TEST_F(OneOffTests, test_non2d_1) { + auto e = NDArrayFactory::create('c', {1, 1}, {5.42746449f}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/non2d_1.fb"); + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(3)); + + auto 
z = graph->getVariableSpace()->getVariable(3)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + ASSERT_EQ(e, *z); + + + delete graph; +} + +TEST_F(OneOffTests, test_reduce_all_1) { + auto e = NDArrayFactory::create('c', {1, 4}, {true, false, false, false}); + + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_all_rank2_d0_keep.fb"); + ASSERT_TRUE(graph != nullptr); + + // graph->printOut(); + + Nd4jStatus status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1)); + + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(2)); + auto in = graph->getVariableSpace()->getVariable(2)->getNDArray(); + + + auto z = graph->getVariableSpace()->getVariable(1)->getNDArray(); + ASSERT_TRUE(z != nullptr); + + ASSERT_EQ(e, *z); + + + delete graph; +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OpTrackerTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OpTrackerTests.cpp new file mode 100644 index 000000000..6d562d393 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OpTrackerTests.cpp @@ -0,0 +1,71 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 15.12.17. +// +#include "testlayers.h" +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class OpTrackerTests : public testing::Test { +public: + int numIterations = 10; + int poolSize = 10; + + OpTrackerTests() { + printf("\n"); + fflush(stdout); + } +}; + +TEST_F(OpTrackerTests, Test_Existence_1) { + sd::_loader loader; + + // nd4j_printf("Groups: %i; Operations: %i\n", OpTracker::getInstance().totalGroups(), OpTracker::getInstance().totalOperations()); + + ASSERT_TRUE(OpTracker::getInstance().totalGroups() > 0); + ASSERT_TRUE(OpTracker::getInstance().totalOperations() > 0); + + OpTracker::getInstance().exportOperations(); +} + +TEST_F(OpTrackerTests, Test_Ops_List_1) { + sd::ops::less op; + auto vec = OpRegistrator::getInstance().getAllHashes(); + + // nd4j_printf("Total ops: %lld\n", vec.size()); + // nd4j_printf("Less hash: %lld\n", op.getOpHash()); + + for (const auto &v: vec) { + if (v == 5484196977525668316L) { + auto op = OpRegistrator::getInstance().getOperation(v); + // nd4j_printf("OpName: %s\n", op->getOpName()->c_str()); + } + } +} + + + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OpTupleTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OpTupleTests.cpp new file mode 100644 index 000000000..e52701ee2 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/OpTupleTests.cpp @@ -0,0 +1,61 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 11.10.2017. +// + +#include "testlayers.h" +#include +#include + +using namespace sd; +using namespace sd::ops; + +class OpTupleTests : public testing::Test { + public: +}; + +TEST_F(OpTupleTests, DirectConstructorTest1) { + auto alpha = NDArrayFactory::create_('c', {1, 2}); + auto beta = NDArrayFactory::create_('c', {1, 2}); + OpTuple tuple("dummy", {alpha, beta}, {12.0f}, {1,2, 3}); + + ASSERT_EQ("dummy", tuple._opName); + ASSERT_EQ(2, tuple._inputs.size()); + ASSERT_EQ(0, tuple._outputs.size()); + ASSERT_EQ(1, tuple._tArgs.size()); + ASSERT_EQ(3, tuple._iArgs.size()); +} + +TEST_F(OpTupleTests, BuilderTest1) { + auto alpha = NDArrayFactory::create_('c', {1, 2}); + auto beta = NDArrayFactory::create_('c', {1, 2}); + OpTuple tuple("dummy"); + tuple.addInput(alpha) + ->addInput(beta) + ->setTArgs({12.0f}) + ->setIArgs({1, 2, 3}); + + + ASSERT_EQ("dummy", tuple._opName); + ASSERT_EQ(2, tuple._inputs.size()); + ASSERT_EQ(0, tuple._outputs.size()); + ASSERT_EQ(1, tuple._tArgs.size()); + ASSERT_EQ(3, tuple._iArgs.size()); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PairwiseTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PairwiseTests.cpp new file mode 100644 index 000000000..59c83bdd8 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PairwiseTests.cpp @@ -0,0 
+1,52 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by agibsonccc on 1/17/17. +// +#include "testinclude.h" +#include + +class EqualsTest : public testing::Test { +public: + const Nd4jLong firstShapeBuffer[8] = {2,1,2,1,1,0,1,102}; + float data[2] = {1.0f, 7.0f}; + const Nd4jLong secondShapeBuffer[8] = {2,2,1,6,1,0,6,99}; + float dataSecond[12] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; + int opNum = 4; + float extraArgs[1] = {1e-6f}; + int dimension[1] = {2147483647}; + int dimensionLength = 1; +}; + +#ifndef __CUDABLAS__ + +TEST_F(EqualsTest,Eps) { + auto val = sd::NDArrayFactory::create(0.0f); + functions::reduce3::Reduce3::execScalar(opNum, + data, + firstShapeBuffer, + extraArgs, + dataSecond, + secondShapeBuffer, + val.buffer(), + val.shapeInfo()); + ASSERT_TRUE(val.e(0) < 0.5); +} + +#endif diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ParityOpsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ParityOpsTests.cpp new file mode 100644 index 000000000..2feba9d63 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ParityOpsTests.cpp 
@@ -0,0 +1,1700 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 12.10.2017. +// + +#include "testlayers.h" +#include +#include + + +using namespace sd; +using namespace sd::ops; + +class ParityOpsTests : public testing::Test { +public: + +}; + + +TEST_F(ParityOpsTests, TestZeroAs1) { + auto x = NDArrayFactory::create('c', {10, 10}); + x.assign(1.0); + + auto exp = NDArrayFactory::create('c', {10, 10}); + exp.assign(0.0f); + + sd::ops::zeros_as op; + + auto result = op.evaluate({&x}, {}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(z->isSameShape(&x)); + ASSERT_TRUE(z->equalsTo(&exp)); + + +} + +TEST_F(ParityOpsTests, TestMaximum1) { + auto x = NDArrayFactory::create('c', {10, 10}); + x.assign(1.0); + + auto y = NDArrayFactory::create('c', {10, 10}); + y.assign(2.0); + + sd::ops::maximum op; + + auto result = op.evaluate({&x, &y}, {}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(y.equalsTo(z)); + + +} + + +TEST_F(ParityOpsTests, TestMinimum1) { + auto x = NDArrayFactory::create('c', {10, 10}); + x.assign(1.0f); + + auto y = NDArrayFactory::create('c', {10, 10}); + y.assign(-2.0f); + + + sd::ops::minimum op; + + auto result = 
op.evaluate({&x, &y}, {}, {}); + + auto z = result.at(0); + + ASSERT_TRUE(y.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, TestTear1) { + auto input = NDArrayFactory::create('c', {10, 5}); + auto tads = input.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(5, tads.at(e)->lengthOf()); + tads.at(e)->assign((float) e + 1); + } + + sd::ops::tear op; + + auto result = op.evaluate({&input}, {}, {1}); + + ASSERT_EQ(10, result.size()); + + for (int e = 0; e < result.size(); e++) + ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); + + +} + +TEST_F(ParityOpsTests, TestUnstack1) { + auto input = NDArrayFactory::create('c', {10, 5}); + auto tads = input.allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(5, tads.at(e)->lengthOf()); + tads.at(e)->assign((float) e + 1); + } + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {0}); + + ASSERT_EQ(10, result.size()); + + for (int e = 0; e < result.size(); e++) + ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); + + +} + + + +TEST_F(ParityOpsTests, TestUnstack2) { + auto input = NDArrayFactory::create('c', {5,2,6}); + auto tads = input.allTensorsAlongDimension({0,1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_EQ(10, tads.at(e)->lengthOf()); + tads.at(e)->assign((float) e + 1); + } + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {2}); + + ASSERT_EQ(6, result.size()); + + for (int e = 0; e < result.size(); e++) + ASSERT_TRUE(tads.at(e)->equalsTo(result.at(e))); + + +} + +TEST_F(ParityOpsTests, TestUnstack3) { + auto input = NDArrayFactory::create('c', {3,2,3}); + auto exp = NDArrayFactory::create('c', {3, 2}, {1.f, 4., 7., 10.f, 13.f, 16.f}); + input.linspace(1); + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(ParityOpsTests, 
TestUnstack4) { + auto input = NDArrayFactory::create('c', {3,2,3}); + auto exp = NDArrayFactory::create('c', {3, 3}, { 1, 2, 3, 7, 8, 9, 13, 14, 15.}); + input.linspace(1); + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, TestUnstack5) { + auto input = NDArrayFactory::create('c', {3,2,3}); + auto exp = NDArrayFactory::create('c', {2, 3}, { 1, 2, 3, 4, 5, 6}); + input.linspace(1); + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, TestUnstack6) { + auto input = NDArrayFactory::create('c', {1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}, {1}); + input.linspace(1); + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, TestUnstack7) { + auto input = NDArrayFactory::create('c', {1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 1}, {1}); + input.linspace(1); + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, TestUnstack8) { + auto input = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {1}); + input.linspace(1); + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + 
+TEST_F(ParityOpsTests, TestUnstack9) { + auto input = NDArrayFactory::create('c', {1, 1}); + auto exp = NDArrayFactory::create('c', {1}, {1}); + input.linspace(1); + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, TestUnstack10) { + + auto input = NDArrayFactory::create('c', {3, 0, 2}); + auto exp = NDArrayFactory::create('c', {0,2}); + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.isSameShape(result.at(1))); + ASSERT_TRUE(exp.isSameShape(result.at(2))); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, TestUnstack11) { + + auto input = NDArrayFactory::create('c', {3, 0, 2}); + auto exp = NDArrayFactory::create('c', {3,0}); + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {2}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(exp.isSameShape(result.at(0))); + ASSERT_TRUE(exp.isSameShape(result.at(1))); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, TestUnstack12) { + + auto input = NDArrayFactory::create('c', {3, 0, 2}); + + sd::ops::unstack op; + + auto result = op.evaluate({&input}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_TRUE(result.size() == 0); + + +} + +TEST_F(ParityOpsTests, TestUnstack13) { + + auto x = NDArrayFactory::create('c', {2, 3}); + + sd::ops::unstack op; + auto result = op.evaluate({&x}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + ASSERT_EQ(3, result.size()); + + for (int e = 0; e < 3; e++) + ASSERT_EQ(1, result.at(e)->rankOf()); + +} + + + 
+TEST_F(ParityOpsTests, ExpandDimsTest1) { + auto input = NDArrayFactory::create('c', {5, 5}); + input.linspace(1); + auto reshaped = input.reshape('c', {5, 1, 5}); + + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); + + +} + + +TEST_F(ParityOpsTests, ExpandDimsTest2) { + auto input = NDArrayFactory::create('c', {3, 4}); + input.linspace(1); + auto reshaped = input.reshape('c', {1, 3, 4}); + + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); + + +} + + +TEST_F(ParityOpsTests, ExpandDimsTest3) { + auto input = NDArrayFactory::create('c', {3, 4}); + input.linspace(1); + auto reshaped = input.reshape('c', {3, 1, 4}); + + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {-2}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, ExpandDimsTest4) { + auto input = NDArrayFactory::create('c', {3, 4}); + input.linspace(1); + auto reshaped = input.reshape('c', {1, 3, 4}); + + sd::ops::expand_dims op; + auto result = op.evaluate({&input}, {}, {-3}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(reshaped.isSameShape(z)); + ASSERT_TRUE(reshaped.equalsTo(z)); + + +} + + +TEST_F(ParityOpsTests, Test_Shape_1) { + auto x = NDArrayFactory::create('c', {3, 4, 5, 6}); + auto exp = NDArrayFactory::create('c', {4}, {3, 4, 5, 6}); + + sd::ops::shape_of op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + 
+} + + +TEST_F(ParityOpsTests, Test_Equals_1) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 0, 3, 0, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 0, 1, 0, 1}); + + sd::ops::equals op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(ParityOpsTests, Test_NotEquals_1) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {1, 0, 3, 0, 5}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 1, 0, 1, 0}); + + sd::ops::not_equals op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, Test_Less_1) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 0, 0, 0}); + + sd::ops::less op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, Test_LessEquals_1) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {1, 1, 1, 0, 0}); + + sd::ops::less_equal op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, Test_GreaterEquals_1) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 
2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1}); + + sd::ops::greater_equal op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, Test_GreaterEquals_2) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 1, 1, 1}); + + sd::ops::greater_equal op; + auto result = op.evaluate({&x, &y}, {}, {}, {}, {}, false); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, Test_Greater_1) { + auto x = NDArrayFactory::create('c', {1, 5}, {1, 2, 3, 4, 5}); + auto y = NDArrayFactory::create('c', {1, 5}, {5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {1, 5}, {0, 0, 0, 1, 1}); + + sd::ops::greater op; + auto result = op.evaluate({&x, &y}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, Test_Where_1) { + auto mask = NDArrayFactory::create('c', {3, 3}, {1, 1, 1, 0, 0, 0, 1, 1, 1}); + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 7, 8, 9}); + + sd::ops::Where op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + // z->printIndexedBuffer("result"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, Test_Where_2) { + auto mask = NDArrayFactory::create('c', {1, 3}, {1, 0, 0}); + auto x = 
NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); + + sd::ops::Where op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(ParityOpsTests, Test_Where_3) { + auto mask = NDArrayFactory::create('c', {2, 2, 3}, {0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1}); + auto exp = NDArrayFactory::create('c', {5, 3}, {0, 0, 1, 0, 0, 2, 0, 1, 1, 1, 0, 0, 1, 1, 2}); + + sd::ops::Where op; + auto result = op.evaluate({&mask}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + // z->printShapeInfo("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, Test_Select_1) { + auto mask = NDArrayFactory::create('c', {1, 3}, {1, 0, 0}); + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto y = NDArrayFactory::create('c', {3, 3}, {9, 8, 7, 6, 5, 4, 3, 2, 1}); + auto exp = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 6, 5, 4, 3, 2, 1}); + + sd::ops::select op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, Test_Select_2) { + auto mask = NDArrayFactory::create('c', {2, 2}, {1, 0, 1, 0}); + auto x = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4 }); + auto y = NDArrayFactory::create('c', {2, 2}, {9, 8, 7, 6}); + auto exp = NDArrayFactory::create('c', {2, 2}, {1, 8, 3, 6}); + + sd::ops::select op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + 
ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, Test_Select_3) { + bool value = false; + auto mask = NDArrayFactory::create('c', {1, 1}, {value}); + auto x = NDArrayFactory::create('c', {1, 1}, {1}); + auto y = NDArrayFactory::create('c', {1, 1}, {2}); + auto exp = NDArrayFactory::create('c', {1, 1}, {2}); + + sd::ops::select op; + auto result = op.evaluate({&mask, &x, &y}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, Test_Bias_Add_1) { + auto x = NDArrayFactory::create('c', {10, 5}); + x.assign(0.0); + auto bias = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + sd::ops::biasadd op; + + auto result = op.evaluate({&x, &bias}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + auto tads = z->allTensorsAlongDimension({1}); + for (int e = 0; e < tads.size(); e++) { + ASSERT_TRUE(bias.equalsTo(tads.at(e))); + } + +} + +TEST_F(ParityOpsTests, Test_Scatter_Add_1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {2, 3, 3, 4}); + + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, Test_Scatter_Add_2) { + + auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + NDArray idc('c', {1, 4}, {0., 1, 2, 3}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 4}, {1, 1, 1, 1}); + auto exp = NDArrayFactory::create('c', {1, 4}, {2, 3, 4, 5}); + + sd::ops::scatter_add op; + auto result = op.evaluate({&vec, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + 
ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, Test_Scatter_Add_3) { + auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2, 2}, {1, 1, 1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {2, 3, 4, 5, 5, 6, 7, 8}); + + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, Test_Scatter_Add_4) { + auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1, 2}, std::vector{0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {3, 4, 5, 6, 5, 6, 7, 8}); + + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, Test_Scatter_Add_5) { + auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1., 1, 0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto exp = NDArrayFactory::create('c', {2, 2, 3}, {9., 11., 13.,15., 17., 19., 9., 11., 13.,15., 17., 19.}); + + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, Test_Scatter_Add_6) { + auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); + NDArray 
idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {2, 2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {7, 9, 11, 13, 7, 9, 11, 13}); + + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, Test_Scatter_Add_7) { + auto matrix = NDArrayFactory::create('c', {10, 3}, {1.f,2.f,3.f,4.f,5.f,6.f,7.f,8.f,9.f,10.f,11.f,12.f,13.f,14.f,15.f,16.f,17.f,18.f,19.f,20.f,21.f,22.f,23.f,24.f,25.f,26.f,27.f,28.f,29.f,30.f}); + NDArray idc('c', {}, std::vector{5}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {3}, {10.f, 20.f, 30.f}); + auto exp = NDArrayFactory::create('c', {10, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f,11.f,12.f, 13.f,14.f,15.f, 26.f,37.f,48.f, 19.f,20.f,21.f, 22.f,23.f,24.f, 25.f,26.f,27.f, 28.f,29.f,30.f}); + + sd::ops::scatter_add op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, Test_Scatter_Add_8) { + + NDArray input('c', {8}, {1,1,1,1,1,1,1,1}, sd::DataType::FLOAT32); + NDArray indices('c', {4}, {1, 1, 1, 1}, sd::DataType::INT32); + NDArray updates('c', {4}, {1,2,3,4}, sd::DataType::FLOAT32); + NDArray expected('c', {8}, {1.f, 11.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f}, sd::DataType::FLOAT32); + + NDArray z('c', {8}, sd::DataType::FLOAT32); + + sd::ops::scatter_add op; + Nd4jStatus status = op.execute({&input, &indices, &updates}, {&z}, {}, {}, {true}); + // z.printBuffer(); + + ASSERT_EQ(ND4J_STATUS_OK, status); + ASSERT_TRUE(expected.isSameShapeStrict(z)); + ASSERT_TRUE(expected.equalsTo(z)); +} + 
+//////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, Test_Scatter_Add_9) { + auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 10, 0, 0}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + auto output = NDArrayFactory::create('c', {2, 2, 3}); + + sd::ops::scatter_add op; + + ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); +} + +//////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterMax_test1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, std::vector{0.}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2}, {10, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {10, 2, 3, 4}); + + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, scatterMax_test2) { + auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + NDArray idc('c', {1, 4}, {0, 1, 2, 3}, sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 4}, {10, 1, 30, 1}); + auto exp = NDArrayFactory::create('c', {1, 4}, {10, 2, 30, 4}); + + sd::ops::scatter_max op; + auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, scatterMax_test3) { + auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT64); + auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 1}); + auto exp = 
NDArrayFactory::create('c', {2, 2, 2}, {10, 2, 30, 4, 5, 6, 7, 8}); + + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, scatterMax_test4) { + auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1,2}, std::vector{0.,0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 10, 10, 10, 5, 6, 7, 8}); + + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {true}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, scatterMax_test5) { + auto matrix = NDArrayFactory::create('c', {2, 2, 3}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 2, 2, 3}, {2,10,1,10, 2,10,1,10, 2,10,1,10, 10,2,10,1, 10,2,10,1, 10,2,10,1.}); + auto exp = NDArrayFactory::create('c', {2, 2, 3}, {10, 2, 10, 2, 10, 2, 2, 10, 2, 10, 2, 10}); + + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, scatterMax_test6) { + auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 1, 1, 1, 1}); + NDArray idc('c', {2, 2}, {1, 1, 0, 0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 2, 2, 2}, {0,2,0,2, 0,2,0,2, 2,0,2,0., 2,0,2,0}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {2, 1, 2, 1, 1, 2, 1, 2}); + + sd::ops::scatter_max op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + 
ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(ParityOpsTests, scatterMin_test1) { + auto matrix = NDArrayFactory::create('c', {2, 2}, {1, 2, 3, 4}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2}, {-1, 1}); + auto exp = NDArrayFactory::create('c', {2, 2}, {-1, 1, 3, 4}); + + sd::ops::scatter_min op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ParityOpsTests, scatterMin_test2) { + auto vec = NDArrayFactory::create('c', {4}, {1, 2, 3, 4}); + NDArray idc('c', {1, 4}, {0, 1, 2, 3}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 4}, {10, 1, 30, 1}); + auto exp = NDArrayFactory::create('c', {1, 4}, {1, 1, 3, 1}); + + sd::ops::scatter_min op; + auto result = op.evaluate({&vec, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, scatterMin_test3) { + auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1}, std::vector({0}), sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2}, {10, 1, 30, 2}); + auto exp = NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 3, 2, 5, 6, 7, 8}); + + sd::ops::scatter_min op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ParityOpsTests, scatterMin_test4) { + auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1,2}, std::vector{0.,0}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); + auto exp = 
NDArrayFactory::create('c', {2, 2, 2}, {1, 1, 1, 1, 5, 6, 7, 8}); + + sd::ops::scatter_min op; + auto result = op.evaluate({&matrix, &idc, &updates}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterMin_test5) { + auto matrix = NDArrayFactory::create('c', {2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}); + NDArray idc('c', {1,2}, {10,10}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {1, 2, 2, 2}, {1,10,1,10, 1,1,10,1.}); + auto output = NDArrayFactory::create('c', {2, 2, 2}); + + sd::ops::scatter_min op; + + ASSERT_ANY_THROW(op.execute({&matrix, &idc, &updates}, {&output}, {}, {}, {true, true})); +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_test1) { + + NDArray indices('c', {2, 1}, {1., 0.}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2, 4}, {10.f, 20.f, 30.f, 40.f, 50.f, 60.f, 70.f, 80.f}); + auto shape = NDArrayFactory::create('c', {2}, {3, 4}); + auto exp = NDArrayFactory::create('c', {3, 4}, {50.f, 60.f, 70.f, 80.f, 10.f, 20.f, 30.f, 40.f, 0.f, 0.f, 0.f, 0.f}); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_test2) { + + NDArray indices('c', {3, 1}, {4., 2., 0.}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3, 4}); + auto shape = NDArrayFactory::create('c', {2}, {5, 4}); + auto exp = NDArrayFactory::create('c', {5, 4}, {9.f,10.f,11.f,12.f, 0.f, 0.f, 0.f, 0.f, 5.f, 6.f, 7.f, 8.f, 0.f, 0.f, 0.f, 0.f, 
1.f, 2.f, 3.f, 4.f}); + updates.linspace(1.f); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_test3) { + + NDArray indices('c', {2, 3, 1}, {0., 2., 7., 3., 6., 9.}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2,3, 3,4}); + auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4}); + auto exp = NDArrayFactory::create('c', {10, 3, 4}, {1.f, 2.f, 3.f, 4., 5.f, 6.f, 7.f, 8., 9.f, 10.f, 11.f, 12., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., + 13.f, 14.f, 15.f, 16.,17.f, 18.f, 19.f, 20.,21.f, 22.f, 23.f, 24.,37.f, 38.f, 39.f, 40.,41.f, 42.f, 43.f, 44.,45.f, 46.f, 47.f, 48., + 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., + 49.f, 50.f, 51.f, 52.,53.f, 54.f, 55.f, 56.,57.f, 58.f, 59.f, 60.,25.f, 26.f, 27.f, 28.,29.f, 30.f, 31.f, 32.,33.f, 34.f, 35.f, 36., + 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0., 0.f, 0.f, 0.f, 0.,61.f, 62.f, 63.f, 64.,65.f, 66.f, 67.f, 68.,69.f, 70.f, 71.f, 72.,}); + updates.linspace(1.f); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {false, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_test4) { + + NDArray indices('c', {4, 1}, {4., 3., 1., 7.}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto shape = NDArrayFactory::create('c', {1}, {8}); + auto exp = NDArrayFactory::create('c', {8}, {0.f, 11.f, 0.f, 10.f, 9.f, 0.f, 0.f, 12.f}); + + 
sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_test5) { + + NDArray indices('c', {4, 1}, {1, 1, 1, 1}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {4}, {1.f, 2.f, 3.f, 4.f}); + auto shape = NDArrayFactory::create('c', {1}, {8}); + auto exp = NDArrayFactory::create('c', {8}, {0.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_test6) { + + NDArray indices('c', {3, 2}, {0,1,1,0,3,2}, sd::DataType::INT32); + NDArray updates('c', {3, 2, 3}, sd::DataType::FLOAT32); + NDArray shape('c', {4}, {5,4,2,3}, sd::DataType::INT32); + + NDArray exp('c', {5,4,2,3}, {0., 0., 0.,0., 0., 0.,1., 2., 3.,4., 5., 6.,0., 0., 0.,0., 0., 0., 0., 0., 0.,0., 0., 0., + 7., 8., 9., 10., 11., 12., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 13., 14., 15., 16., 17., 18., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); + updates.linspace(1); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + 
ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_test7) { + + NDArray indices('c', {4,3,2}, {0,1,1,0,3,2,1,0,0,1,1,0,3,2,1,0,0,1,1,0,3,2,1,0}, sd::DataType::INT32); + NDArray updates('c', {4,3,2,3}, sd::DataType::FLOAT32); + NDArray shape('c', {4}, {5,4,2,3}, sd::DataType::INT32); + + NDArray exp('c', {5,4,2,3}, {0., 0., 0., 0., 0., 0., 75., 78., 81., 84., 87., 90., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 222., 228., 234., 240., 246., 252., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 111., 114., 117., 120., 123., 126., 0., 0., 0., 0., 0., 0., + 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.}, sd::DataType::FLOAT32); + updates.linspace(1); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {}, {true, true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_test8) { + + NDArray indices('c', {3, 2}, {0,0, 1,1, 2,2}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3}, {1.f, 2.f, 3.f}); + auto shape = NDArrayFactory::create('c', {2}, {6,4}); + auto exp = NDArrayFactory::create('c', {6,4}, {1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); + + sd::ops::scatter_nd op; + auto result = op.evaluate({&indices, &updates, &shape}, {}, {true}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + 
+ +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_test9) { + + NDArray indices('c', {2, 3, 1}, {0., 20., 7., 30., 6., 90.}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2,3, 3,4}); + auto shape = NDArrayFactory::create('c', {3}, {10, 3, 4}); + auto output = NDArrayFactory::create('c', {10, 3, 4}); + + sd::ops::scatter_nd op; + + ASSERT_ANY_THROW(auto result = op.execute({&indices, &updates, &shape}, {&output}, {}, {}, {false, true})); +} + + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_add_test1) { + + auto input = NDArrayFactory::create('c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray indices('c', {4, 1}, {4., 3., 1., 7.}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto exp = NDArrayFactory::create('c', {8}, {1.f, 13.f, 3.f, 14.f, 14.f, 6.f, 7.f, 20.f}); + + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_add_test2) { + + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3,3}); + auto exp = NDArrayFactory::create('c', {6,4}, {1.f,0.f,7.f,0.f, 0.f,2.f,0.f,8.f, 9.f,0.f,3.f,0.f, 0.f,0.f,0.f,4.f, 5.f,0.f,0.f,0.f, 0.f,6.f,0.f,0.f}); + + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // 
z->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_add_test3) { + + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {6,4}, {21.f, 22.f, 23.f, 24.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f, 1.f, 2.f, 3.f, 4.f}); + + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_add_test4) { + + auto input = NDArrayFactory::create('c', {6, 4, 5}); + NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3,3,5}); + auto exp = NDArrayFactory::create('c', {6,4,5}, {1.f, 2.f, 3.f, 4.f, 5.f, 0.f, 0.f, 0.f, 0.f, 0.f,31.f, 32.f, 33.f, 34.f, 35.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 6.f, 7.f, 8.f, 9.f, 10.f, 0.f, 0.f, 0.f, 0.f, 0.f,36.f, 37.f, 38.f, 39.f, 40.f, + 41.f, 42.f, 43.f, 44.f, 45.f, 0.f, 0.f, 0.f, 0.f, 0.f,11.f, 12.f, 13.f, 14.f, 15.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f,26.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + 
updates.linspace(1.f); + + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_add_test5) { + + auto input = NDArrayFactory::create('c', {6,5,4,3,2}); + NDArray indices('c', {2,2,3}, {0.f,0.f,0.f, 1.f,1.f,1.f, 2.f,2.f,2.f, 3.f,3.f,3.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2,2,3,2}); + auto exp = NDArrayFactory::create('c', {6,5,4,3,2}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f, 9.f, 10.f,11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,13.f, 14.f,15.f, 16.f,17.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,19.f, 20.f,21.f, 22.f,23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_add op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_add_test6) { + + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {50.f, 1.f, 2.f, 3.f, 40.f, 0.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2,3,4}); + auto output = NDArrayFactory::create('c', {6,4}); + + sd::ops::scatter_nd_add op; + + ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {false, true})); +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_sub_test1) { + + auto input = NDArrayFactory::create('c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray indices('c', {4, 1}, {4.f, 3.f, 1.f, 7.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto exp = NDArrayFactory::create('c', {8}, {1.f, -9.f, 3.f, -6.f, -4.f, 6.f, 7.f, -4.f}); + + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_sub_test2) { + + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray 
indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3,3}); + auto exp = NDArrayFactory::create('c', {6,4}, {-1.f,0.f,-7.f,0.f, 0.f,-2.f,0.f,-8.f, -9.f,0.f,-3.f,0.f, 0.f,0.f,0.f,-4.f, -5.f,0.f,0.f,0.f, 0.f,-6.f,0.f,0.f}); + + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + //exp.printIndexedBuffer("e"); + //z->printIndexedBuffer("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_sub_test3) { + + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f,4.f, 0.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {6,4}, {-21.f,-22.f,-23.f,-24., -5.f, -6.f, -7.f, -8., -9.f,-10.f,-11.f,-12., -13.f,-14.f,-15.f,-16., -17.f,-18.f,-19.f,-20., -1.f, -2.f, -3.f, -4.f}); + + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_sub_test4) { + + auto input = NDArrayFactory::create('c', {6, 4, 5}); + NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3,3,5}); + auto exp = NDArrayFactory::create('c', {6,4,5}, {-1.f, -2.f, -3.f, -4.f, -5.f, 0.f, 0.f, 0.f, 0.f, 0.f,-31.f, -32.f, -33.f, 
-34.f, -35.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, -6.f, -7.f, -8.f, -9.f, -10.f, 0.f, 0.f, 0.f, 0.f, 0.f,-36.f, -37.f, -38.f, -39.f, -40.f, + -41.f, -42.f, -43.f, -44.f, -45.f, 0.f, 0.f, 0.f, 0.f, 0.f,-11.f, -12.f, -13.f, -14.f, -15.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,-16.f, -17.f, -18.f, -19.f, -20.f, + -21.f, -22.f, -23.f, -24.f, -25.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f,-26.f, -27.f, -28.f, -29.f, -30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_sub_test5) { + + auto input = NDArrayFactory::create('c', {6,5,4,3,2}); + NDArray indices('c', {2,2,3}, {0.f,0.f,0.f, 1.f,1.f,1.f, 2.f,2.f,2.f, 3.f,3.f,3.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2,2,3,2}); + auto exp = NDArrayFactory::create('c', {6,5,4,3,2}, { -1.f, -2.f, -3.f, -4.f, -5.f, -6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, -7.f, -8.f, -9.f, -10.f,-11.f, -12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,-13.f, -14.f,-15.f, -16.f,-17.f, -18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,-19.f, -20.f,-21.f, -22.f,-23.f,-24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, +0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f}); + input = 0.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_sub op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_update_test1) { + + auto input = NDArrayFactory::create('c', {8}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f}); + NDArray indices('c', {4, 1}, {4.f, 3.f, 1.f, 7.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {4}, {9.f, 10.f, 11.f, 12.f}); + auto exp = NDArrayFactory::create('c', {8}, {1.f, 11.f, 3.f, 10.f, 9.f, 6.f, 7.f, 12.f}); + + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + 
+//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_update_test2) { + + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3,3}); + auto exp = NDArrayFactory::create('c', {6,4}, {1.f,-1.f,7.f,-1.f, -1.f,2.f,-1.f,8.f, 9.f,-1.f,3.f,-1.f, -1.f,-1.f,-1.f,4.f, 5.f,-1.f,-1.f,-1.f, -1.f,6.f,-1.f,-1.f}); + + input = -1.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_update_test3) { + + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {2, 3, 1}, {5.f, 1.f, 2.f, 3.f, 4.f, 0.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2,3,4}); + auto exp = NDArrayFactory::create('c', {6,4}, {21.f, 22.f, 23.f, 24.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f, 1.f, 2.f, 3.f, 4.f,}); + + input = -1.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_update_test4) { + + auto input = NDArrayFactory::create('c', {6, 4, 5}); + NDArray indices('c', {3, 3, 2}, {0.f,0.f, 1.f,1.f, 2.f,2.f, 3.f,3.f, 4.f,0.f, 5.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, 
sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3,3,5}); + auto exp = NDArrayFactory::create('c', {6,4,5}, {1.f, 2.f, 3.f, 4.f, 5.f, -1.f, -1.f, -1.f, -1.f, -1.f,31.f, 32.f, 33.f, 34.f, 35.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, 6.f, 7.f, 8.f, 9.f, 10.f, -1.f, -1.f, -1.f, -1.f, -1.f,36.f, 37.f, 38.f, 39.f, 40.f, + 41.f, 42.f, 43.f, 44.f, 45.f, -1.f, -1.f, -1.f, -1.f, -1.f,11.f, 12.f, 13.f, 14.f, 15.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f,16.f, 17.f, 18.f, 19.f, 20.f, + 21.f, 22.f, 23.f, 24.f, 25.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f,26.f, 27.f, 28.f, 29.f, 30.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f}); + input = -1.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_update_test5) { + + auto input = NDArrayFactory::create('c', {6,5,4,3,2}); + NDArray indices('c', {2,2,3}, {0.f,0.f,0.f, 1.f,1.f,1.f, 2.f,2.f,2.f, 3.f,3.f,3.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {2,2,3,2}); + auto exp = NDArrayFactory::create('c', {6,5,4,3,2}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 
-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 7.f, 8.f, 9.f, 10.f,11.f, 12.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f,13.f, 14.f,15.f, 16.f,17.f, 18.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 
-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f,19.f, 20.f,21.f, 22.f,23.f, 24.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, +-1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f}); + input = -1.f; + updates.linspace(1.f); + + sd::ops::scatter_nd_update op; + auto result = op.evaluate({&input, &indices, &updates}, 
{}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +//////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatterND_update_test6) { + + auto input = NDArrayFactory::create('c', {6, 4}); + NDArray indices('c', {3, 3, 2}, {0.f,0.f, 10.f,1.f, 20.f,2.f, 30.f,3.f, 40.f,0.f, 50.f,1.f, 0.f,2.f, 1.f,3.f, 2.f,0.f}, sd::DataType::INT32); + auto updates = NDArrayFactory::create('c', {3,3}); + auto output = NDArrayFactory::create('c', {6,4}); + + sd::ops::scatter_nd_update op; + + ASSERT_ANY_THROW(op.execute({&input, &indices, &updates}, {&output}, {}, {}, {true, true})); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatter_update_1) { + + NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); + NDArray updates('c', {2,2}, {10,20,30,40}, sd::DataType::INT32); + + NDArray exp('c', {2,2}, {30,40,10,20}, sd::DataType::INT32); + + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 1,1, 2,1,0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + // x.printBuffer(); + + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatter_update_2) { + + NDArray x('c', {2,2}, {1,2,3,4}, sd::DataType::INT32); + NDArray updates('c', {2,2}, {10,20,30,40}, sd::DataType::INT32); + + NDArray exp('c', {2,2}, {20,10,40,30}, sd::DataType::INT32); + + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,1,0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatter_update_3) { + + NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, sd::DataType::INT32); + 
NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, sd::DataType::INT32); + + NDArray exp('c', {2,2,2}, {50,60,70,80,10,20,30,40}, sd::DataType::INT32); + + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 2,1,2, 2,1,0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); + + +} + +////////////////////////////////////////////////////////////////////// +TEST_F(ParityOpsTests, scatter_update_4) { + + NDArray x('c', {2,2,2}, {1,2,3,4,5,6,7,8}, sd::DataType::INT32); + NDArray updates('c', {2,2,2}, {10,20,30,40,50,60,70,80}, sd::DataType::INT32); + + NDArray exp('c', {2,2,2}, {20,2,3,10,60,6,7,50}, sd::DataType::INT32); + + sd::ops::scatter_update op; + auto results = op.evaluate({&x, &updates}, {}, {6, 1,0, 2,3,0}); + + ASSERT_EQ(ND4J_STATUS_OK, results.status()); + + ASSERT_TRUE(exp.isSameShape(x)); + ASSERT_TRUE(exp.equalsTo(x)); + + +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PerformanceTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PerformanceTests.cpp new file mode 100644 index 000000000..48e32d701 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PerformanceTests.cpp @@ -0,0 +1,146 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace sd; +using namespace sd::graph; + +class PerformanceTests : public testing::Test { +public: + int numIterations = 100; + + PerformanceTests() { + samediff::ThreadPool::getInstance(); + } +}; + + +#ifdef RELEASE_BUILD + +TEST_F(PerformanceTests, test_matmul_c_f_1) { + int iterations = 500; + std::vector valuesC, valuesF; + for (int e = 0; e < iterations; e++) { + auto xc = NDArrayFactory::create('c', {512, 2048}); + auto yc = NDArrayFactory::create('c', {2048, 512}); + auto zc = NDArrayFactory::create('c', {512, 512}); + + auto xf = NDArrayFactory::create('f', {512, 2048}); + auto yf = NDArrayFactory::create('f', {2048, 512}); + auto zf = NDArrayFactory::create('f', {512, 512}); + + auto warm = xc.like(); + warm.linspace(1.0); + + //zc.linspace(1.0); + //zf.linspace(1.0); + + sd::ops::matmul op; + + auto timeStartF = std::chrono::system_clock::now(); + + op.execute({&xf, &yf}, {&zf}); + + auto timeEndF = std::chrono::system_clock::now(); + auto outerTimeF = std::chrono::duration_cast(timeEndF - timeStartF).count(); + + + auto timeStartC = std::chrono::system_clock::now(); + + op.execute({&xc, &yc}, {&zc}); + + auto timeEndC = std::chrono::system_clock::now(); + auto outerTimeC = std::chrono::duration_cast(timeEndC - timeStartC).count(); + + valuesF.emplace_back(outerTimeF); + valuesC.emplace_back(outerTimeC); + } + + std::sort(valuesC.begin(), valuesC.end()); + std::sort(valuesF.begin(), valuesF.end()); + + + 
nd4j_printf("Median time C: [%lld]; Median time F: [%lld];", valuesC[valuesC.size() / 2], valuesF[valuesF.size() / 2]); +} + +TEST_F(PerformanceTests, test_maxpooling2d_1) { + std::vector valuesX; + // auto x = NDArrayFactory::create('c', {32, 3, 224, 224}); + // auto z = NDArrayFactory::create('c', {32, 3, 224, 224}); + auto x = NDArrayFactory::create('c', {8, 3, 64, 64}); + auto z = NDArrayFactory::create('c', {8, 3, 64, 64}); + x.linspace(1.0f); + Nd4jLong k = 5; + + + Nd4jLong iArgs[] {k,k, 1,1, 0,0, 1,1, 1}; + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setOutputArray(0, &z); + ctx.setIArguments(iArgs, 9); + + sd::ops::maxpool2d op; + + for (int i = 0; i < numIterations; i++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute(&ctx); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + valuesX.emplace_back(outerTime); + + if ((i + 1) % 1000 == 0) + nd4j_printf("Iteration %i finished...\n", i + 1); + } + + std::sort(valuesX.begin(), valuesX.end()); + nd4j_printf("Execution time: %lld; Min: %lld; Max: %lld;\n", valuesX[valuesX.size() / 2], valuesX[0], valuesX[valuesX.size() - 1]); +} + +#endif \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PlaygroundTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PlaygroundTests.cpp new file mode 100644 index 000000000..f476fb05a --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -0,0 +1,1686 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. 
+ * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +// +// Created by raver119 on 20.11.17. +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace sd; +using namespace sd::graph; + +class PlaygroundTests : public testing::Test { +public: + int numIterations = 3; + int poolSize = 10; + + PlaygroundTests() { + } +}; + +TEST_F(PlaygroundTests, test_avx) { + nd4j_printf("Optimal level: %i; Binary level: %i;\n", ::optimalLevel(), ::binaryLevel()); +} + +TEST_F(PlaygroundTests, buildver) { + nd4j_printf("%s\n", buildInfo()); +} + +TEST_F(PlaygroundTests, test_biasAdd_1) { + auto x = NDArrayFactory::create('c', {512, 3072}); + auto y = NDArrayFactory::create('c', {3072}); + + std::vector values; + + sd::ops::biasadd op; + + for (int e = 0; e < 100; e++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute({&x, &y}, {&x}); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +} + + +TEST_F(PlaygroundTests, test_bert_full_1) 
{ +#ifdef _RELEASE + + // this test will run ONLY if this model exists + if (sd::graph::getFileSize("/home/raver119/Downloads/BertFull/model.fb") < 0) + return; + + auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/BertFull/model.fb"); + + auto t = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in0_IteratorGetNext.npy"); + auto u = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in1_IteratorGetNext_1.npy"); + auto v = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/in2_IteratorGetNext_4.npy"); + auto z = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/BertFull/out_loss-Softmax.npy"); + + //graph->printOut(); + + graph->tagInplaceNodes(); + + graph->getVariableSpace()->putVariable(658,0, t); + graph->getVariableSpace()->putVariable(659,0, u); + graph->getVariableSpace()->putVariable(660,0, v); + +/* + // validating graph now + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(1620)); + + auto array = graph->getVariableSpace()->getVariable(1620)->getNDArray(); + ASSERT_EQ(z, *array); + +*/ + + sd::Environment::getInstance().setProfiling(true); + auto profile = GraphProfilingHelper::profile(graph, 1); + + profile->printOut(); + + sd::Environment::getInstance().setProfiling(false); + delete profile; + +/* + std::vector values; + + for (int e = 0; e < 1; e++) { + auto timeStart = std::chrono::system_clock::now(); + + GraphExecutioner::execute(graph); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +*/ + delete graph; +#endif +} + + +TEST_F(PlaygroundTests, test_bert_1) { +#ifdef _RELEASE + // this test will run ONLY if this model exists + if 
(sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb") < 0) + return; + + auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model.fb"); + + auto t = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext.numpy"); + auto u = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext_1.numpy"); + auto v = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_input_IteratorGetNext_4.numpy"); + auto z = NDArrayFactory::fromNpyFile("/home/raver119/Downloads/Bert_minimal_model/bert_minimal_model_output.numpy"); + + //graph->printOut(); + + graph->tagInplaceNodes(); + + graph->getVariableSpace()->putVariable(85,0, t); + graph->getVariableSpace()->putVariable(86,0, u); + graph->getVariableSpace()->putVariable(87,0, v); + +/* + // validating graph now + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198)); + + auto array = graph->getVariableSpace()->getVariable(198)->getNDArray(); + ASSERT_EQ(z, *array); + +*/ + sd::Environment::getInstance().setProfiling(true); + auto profile = GraphProfilingHelper::profile(graph, 1); + + profile->printOut(); + + sd::Environment::getInstance().setProfiling(false); + delete profile; + +/* + std::vector values; + + for (int e = 0; e < 1; e++) { + auto timeStart = std::chrono::system_clock::now(); + + GraphExecutioner::execute(graph); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +*/ + delete graph; +#endif +} + +TEST_F(PlaygroundTests, test_bert_2) { +#ifdef _RELEASE + // this test will run ONLY if 
this model exists + if (sd::graph::getFileSize("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb") < 0) + return; + + auto graph = GraphExecutioner::importFromFlatBuffers("/home/raver119/Downloads/Bert_minimal_model/bert_like_ops.fb"); + + //graph->printOut(); + + graph->tagInplaceNodes(); + + +/* + // validating graph now + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + ASSERT_TRUE(graph->getVariableSpace()->hasVariable(198)); + + auto array = graph->getVariableSpace()->getVariable(198)->getNDArray(); + ASSERT_EQ(z, *array); +*/ + + sd::Environment::getInstance().setProfiling(true); + auto profile = GraphProfilingHelper::profile(graph, 1); + + profile->printOut(); + + sd::Environment::getInstance().setProfiling(false); + delete profile; + +/* + std::vector values; + + for (int e = 0; e < 1; e++) { + auto timeStart = std::chrono::system_clock::now(); + + GraphExecutioner::execute(graph); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +*/ + delete graph; +#endif +} + + +TEST_F(PlaygroundTests, test_one_off_ops_1) { + auto x = NDArrayFactory::create('c', {4, 128, 768}); + auto y = NDArrayFactory::create('c', {4, 128, 1}); + auto z = x.ulike(); + + sd::ops::squaredsubtract op; + op.execute({&x, &y}, {&z}); +} + +#if defined(INDEX_REDUCTIONS_BENCH_TESTS) +//temporarly, testing against the original one +void original_argmax(const NDArray& input, std::vector& axis, NDArray& output) { + sd::ops::helpers::adjustAxis(input.rankOf(), axis); + input.applyIndexReduce(sd::indexreduce::IndexMax, output, axis); +} + +template +void fill_random(sd::NDArray& arr) { + Nd4jLong coords[MAX_RANK] = {}; + std::random_device rd; + std::mt19937 gen(rd()); + //for floats + std::uniform_real_distribution 
dis((T)-10.0, (T)22.9); + T* x = arr.bufferAsT(); + Nd4jLong* shapeInfo = arr.getShapeInfo(); + Nd4jLong* strides = arr.stridesOf(); + Nd4jLong rank = shapeInfo[0]; + Nd4jLong* bases = &(shapeInfo[1]); + size_t t = 1; + for (size_t i = 0; i < rank ; i++) { + t *= bases[i]; + } + size_t offset = 0; + if (arr.ordering() == 'c') { + + for (size_t i = 0; i < t; i++) { + x[offset] = dis(gen) ; + offset = sd::inc_coords(bases, strides, coords, offset, rank); + } + + } + else { + + for (size_t i = 0; i < t; i++) { + x[offset] = dis(gen) ; + offset = sd::inc_coords(bases, strides, coords, offset, rank); + } + + } +} + +void testLegacy(bool random) { +#if 0 + int bases[] = { 3, 2, 4, 5, 7 }; + constexpr int Loop = 1; +#else + int bases[] = { 8, 32, 64, 32, 64 }; + constexpr int Loop = 10; +#endif + constexpr int N = 5; + + auto x = NDArrayFactory::create('c', { bases[0], bases[1], bases[2], bases[3], bases[4] }); + if (!random) { + x.linspace(1); + } + else{ + fill_random(x); + } + +#define COMBINATIONS 1 +#if COMBINATIONS +//https://www.rosettacode.org/wiki/Combinations#C.2B.2B +for (int k = N; k >= 1; k--) { + + std::string bitmask(k, 1); // K leading 1's + bitmask.resize(N, 0); // N-K trailing 0's + + do { + + + std::vector dimension; + + std::vector output_bases; + + for (int i = 0; i < N; ++i) // [0..N-1] integers + { + if (bitmask[i]) dimension.push_back(i); + else { + output_bases.push_back(bases[i]); + } + } +#else +std::vector dimension = { 0,1,2,3 }; +int k = 4; +#endif +auto dim = NDArrayFactory::create(dimension); + +#if 1 +nd4j_printf("C(N:%d K:%d) \n", N, k); +dim.printIndexedBuffer("Dimension"); +for (int xind : dimension) { + nd4j_printf(" %d ,", bases[xind]); +} +nd4j_printf("%s", "\n"); +#endif + + + +std::vector values; +sd::ResultSet result; +for (int e = 0; e < Loop; e++) { + auto timeStart = std::chrono::system_clock::now(); + NDArray exp = output_bases.size() > 0 ? 
NDArrayFactory::create('c', output_bases) : NDArrayFactory::create(0); + original_argmax(x, dimension, exp); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); +} + +std::sort(values.begin(), values.end()); + +nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +#if COMBINATIONS + + } while (std::prev_permutation(bitmask.begin(), bitmask.end())); + +} +#endif +} + +#define DEBUG 1 + +void testNewReduction(bool random, bool checkCorrectness = false , char order ='c') { + std::vector arr_dimensions; +#if defined(DEBUG) + int bases[] = { 3, 2, 3, 3, 5 ,4,7,4,7,7 }; + constexpr int Loop = 1; + constexpr int N = 10; +#else + int bases[] = { 8, 32, 64, 32, 64 }; + constexpr int Loop = 10; + constexpr int N = 5; + +#endif + + for (int i = 0; i < N; i++) { + arr_dimensions.push_back(bases[i]); + } + auto x = NDArrayFactory::create(order,arr_dimensions); + if (!random) { + x.linspace(1); + } + else { + fill_random(x); + } + +#define COMBINATIONS 1 +#if COMBINATIONS + //https://www.rosettacode.org/wiki/Combinations#C.2B.2B + for (int k = N; k >= 1; k--) { + + std::string bitmask(k, 1); // K leading 1's + bitmask.resize(N, 0); // N-K trailing 0's + + do { + + + std::vector dimension; + + std::vector output_bases; + + for (int i = 0; i < N; ++i) // [0..N-1] integers + { + if (bitmask[i]) dimension.push_back(i); + else { + output_bases.push_back(bases[i]); + } + } +#else + std::vector dimension = { 0,1,2,3 }; + int k = 4; +#endif + auto dim = NDArrayFactory::create(dimension); + +#if 1 + nd4j_printf("C(N:%d K:%d) \n", N, k); + dim.printIndexedBuffer("Dimension"); + for (int xind : dimension) { + nd4j_printf(" %d ,", bases[xind]); + } + nd4j_printf("%s", "\n"); +#endif + + + sd::ops::argmax op; + std::vector values; + sd::ResultSet result; + for (int e = 0; e < Loop; e++) { + auto timeStart = std::chrono::system_clock::now(); + result = op.evaluate({ &x, &dim }, 
{}, {}); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + auto z = result.at(0); + + if (checkCorrectness) { + //check for the correctness + NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create('c', output_bases) : NDArrayFactory::create(0); + original_argmax(x, dimension, exp); + + +#if 0// defined(DEBUG) + x.printIndexedBuffer("X"); + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + } + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +#if COMBINATIONS + + } while (std::prev_permutation(bitmask.begin(), bitmask.end())); + + } +#endif +} + +constexpr bool test_corr = true; +#if !defined(DEBUG) +TEST_F(PlaygroundTests, ArgMaxPerfLinspace) { + testNewReduction(false, test_corr); +} +#endif + +TEST_F(PlaygroundTests, ArgMaxPerfRandom) { + testNewReduction(true, test_corr); +} + +TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) { + testNewReduction(true, test_corr, 'f'); +} + +#if !defined(DEBUG) +TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) { + testLegacy(false); +} + +TEST_F(PlaygroundTests, ArgMaxPerfLegacyRandom) { + testLegacy(true); +} + +#endif + +#endif + +/* + +TEST_F(PlaygroundTests, test_broadcast_1) { + int pool = 1000; + std::vector aX(pool); + std::vector aY(pool); + std::vector aZ(pool); + + for (int e = 0; e < pool; e++) { + aX[e] = NDArrayFactory::create_('c', {512, 3072}); + aY[e] = NDArrayFactory::create_('c', {3072}); + aZ[e] = NDArrayFactory::create_('c', {512, 3072}); + + aX[e]->assign(119 * (e+1)); + aY[e]->assign(119 * (e+3)); + } + + std::vector values; + Context ctx(1); + + sd::ops::biasadd op; + + for (int e = 0; e < 1000; e++) { + auto x = aX[e < pool ? e : e % pool]; + auto y = aY[e < pool ? e : e % pool]; + auto z = aZ[e < pool ? 
e : e % pool]; + + auto timeStart = std::chrono::system_clock::now(); + + //op.execute({x, y}, {z}); + sd::ops::helpers::addBias(ctx, *x, *y, *z, false); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); + + for (int e = 0; e < pool; e++) { + delete aX[e]; + delete aY[e]; + delete aZ[e]; + } +} + + +/* +TEST_F(PlaygroundTests, test_broadcast_1) { + int pool = 500; + std::vector aX(pool); + std::vector aY(pool); + std::vector aZ(pool); + + for (int e = 0; e < pool; e++) { + aX[e] = NDArrayFactory::create_('c', {512, 3072}); + aY[e] = NDArrayFactory::create_('c', {768}); + aZ[e] = NDArrayFactory::create_('c', {512, 3072}); + + aX[e]->assign( (e+1) / 119); + aY[e]->assign( (e+3) / 119); + } + + + + std::vector values; + + for (int e = 0; e < 1000; e++) { + auto x = aX[e < pool ? e : e % pool]; + auto y = aY[e < pool ? e : e % pool]; + auto z = aZ[e < pool ? 
e : e % pool]; + + auto timeStart = std::chrono::system_clock::now(); + + //x->applyTrueBroadcast(BroadcastOpsTuple::Multiply(), *y, *z); + x->applyTransform(transform::Tanh, *z, nullptr); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); + + for (int e = 0; e < pool; e++) { + delete aX[e]; + delete aY[e]; + delete aZ[e]; + } +} + +*/ +/* + +TEST_F(PlaygroundTests, test_s_0) { + std::vector> shapes = {{32, 224, 224, 3}, {32, 56, 56, 64}, {32, 7, 7, 512}}; + std::vector threads = {1, 2, 4, 8, 16}; + + for (auto shape: shapes) { + for (auto t: threads) { + sd::Environment::getInstance().setMaxMasterThreads(t); + + auto x = NDArrayFactory::create('c', shape); + auto y = NDArrayFactory::create('c', {shape[3]}); + auto z = x.ulike(); + + std::vector values; + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + sd::ops::biasadd op; + + + for (int e = 0; e < 10000; e++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute(&ctx); + sd::ops::helpers::addBias(ctx, x, y, z, false); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Shape: [%lld, %lld, %lld, %lld]; Threads: [%i]; Time: %lld us;\n", shape[0], shape[1], shape[2], shape[3], t, values[values.size() / 2]); + } + } +} + +TEST_F(PlaygroundTests, test_s_1) { + std::vector> shapes = {{32, 3, 224, 224}, {32, 64, 56, 56}, {32, 512, 7, 7}}; + std::vector threads = {1, 2, 4, 8, 16}; + + for (auto shape: shapes) { + for (auto t: threads) { + sd::Environment::getInstance().setMaxMasterThreads(t); + + auto x = NDArrayFactory::create('c', shape); + auto y = 
NDArrayFactory::create('c', {shape[1]}); + auto z = x.ulike(); + + std::vector values; + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + sd::ops::biasadd op; + + + for (int e = 0; e < 10000; e++) { + auto timeStart = std::chrono::system_clock::now(); + + //op.execute({&x, &y}, {&z}, {true}); + sd::ops::helpers::addBias(ctx, x, y, z, true); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Shape: [%lld, %lld, %lld, %lld]; Threads: [%i]; Time: %lld us;\n", shape[0], shape[1], shape[2], shape[3], t, values[values.size() / 2]); + } + } +} +*/ + +/* +TEST_F(PlaygroundTests, test_s_0) { + auto x = NDArrayFactory::create('c', {32, 112, 112, 16}); + auto y = NDArrayFactory::create('c', {16}); + auto z = x.ulike(); + + std::vector values; + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + sd::ops::biasadd op; + + + for (int e = 0; e < 10000; e++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute(&ctx); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +} +*/ +/* +TEST_F(PlaygroundTests, test_s_1) { + auto x0 = NDArrayFactory::create('c', {32, 7, 7, 176}); + auto x1 = x0.ulike(); + auto x2 = x0.ulike(); + auto x3 = x0.ulike(); + auto x4 = x0.ulike(); + auto x5 = x0.ulike(); + + auto y = NDArrayFactory::create(3); + auto z = NDArrayFactory::create('c', {32, 7, 7, 1056}); + + Context ctx(1); + ctx.setInputArray(0, &x0); + ctx.setInputArray(1, &x1); + ctx.setInputArray(2, &x2); + ctx.setInputArray(3, &x3); + ctx.setInputArray(4, &x4); + 
ctx.setInputArray(5, &x5); + + ctx.setInputArray(6, &y); + ctx.setOutputArray(0, &z); + ctx.setBArguments({true}); + + std::vector values; + + sd::ops::concat op; + op.execute(&ctx); + + for (int e = 0; e < 1000; e++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute(&ctx); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +} +*/ + +/* +TEST_F(PlaygroundTests, test_s_1) { + auto t = ::runLightBenchmarkSuit(true); + delete[] t; +} + +TEST_F(PlaygroundTests, test_s_2) { + std::atomic s; + s = 0; + auto func = PRAGMA_THREADS_FOR { + s++; + }; + + samediff::Threads::parallel_for(func, 0, 8192, 1, 4); + std::vector values; + + for (int e = 0; e < 100000; e++) { + s = 0; + + auto timeStart = std::chrono::system_clock::now(); + //samediff::Threads::parallel_for(func, 0, 8192, 1, 4); + PRAGMA_OMP_PARALLEL_THREADS(4) { + s++; + } + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + values.emplace_back(outerTime); + }; + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld;\n", values[values.size() / 2]); +} + */ +/* +TEST_F(PlaygroundTests, test_s_4) { + std::atomic f; + std::atomic s; + std::vector valuesX, valuesY; + int iterations = 1000; + s = 0; + auto func = PRAGMA_THREADS_FOR { + s++; + }; + + samediff::Threads::parallel_for(func, 0, 8192, 1, 4); + + //////// + + auto x = NDArrayFactory::create('c', {32, 3, 256, 256}); + auto z = NDArrayFactory::create('c', {32, 3, 256, 256}); + x.linspace(1.0); + + auto xs0 = x.sizeAt(0); + auto xs1 = x.sizeAt(1); + auto xs2 = x.sizeAt(2); + auto xs3 = x.sizeAt(3); + + auto buffer = x.bufferAsT(); + auto zbuffer = z.bufferAsT(); + + for (int e = 0; e < iterations; e++) { + auto timeStart = 
std::chrono::system_clock::now(); + PRAGMA_OMP_PARALLEL_FOR_COLLAPSE(2) + for (int i = 0; i < xs0; i++) { + for (int j = 0; j < xs1; j++) { + auto thread_id = omp_get_thread_num(); + for (int k = 0; k < xs2; k++) { + for (int l = 0; l < xs3; l++) { + zbuffer[thread_id] += buffer[i * j + (k*l)] * 2.5f; + } + } + } + } + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + valuesX.emplace_back(outerTime); + } + + + for (int e = 0; e < iterations; e++) { + auto timeStart = std::chrono::system_clock::now(); + auto f2d = PRAGMA_THREADS_FOR_2D { + for (auto i = start_x; i < stop_x; i++) { + for (auto j = start_y; j < stop_y; j++) { + + for (auto k = 0; k < xs2; k++) { + for (auto l = 0; l < xs3; l++) { + zbuffer[thread_id] += buffer[i * j + (k * l)] * 2.5f; + } + } + } + } + }; + samediff::Threads::parallel_for(f2d, 0, xs0, 1, 0, xs1, 1); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + valuesY.emplace_back(outerTime); + } + + if (valuesX.size() > 0) { + std::sort(valuesX.begin(), valuesX.end()); + nd4j_printf("OpenMP time: %lld; Min: %lld; Max: %lld;\n", valuesX[valuesX.size() / 2], valuesX[0], valuesX[valuesX.size() - 1]); + } + + if (valuesY.size() > 0) { + std::sort(valuesY.begin(), valuesY.end()); + nd4j_printf("Threads time: %lld; Min: %lld; Max: %lld;\n", valuesY[valuesY.size() / 2], valuesY[0], valuesY[valuesY.size() - 1]); + } + + nd4j_printf("Sum: %f\n", z.sumNumber().e(0)); +} + + +TEST_F(PlaygroundTests, test_s_5) { + auto x = NDArrayFactory::create('c', {32, 1, 28, 28}); + + std::vector values; + auto iterations = 100; + + auto startX = 0; + auto stopX = x.sizeAt(0); + auto incX = 1; + auto startY = 0; + auto stopY = x.sizeAt(1); + auto incY = 1; + auto numThreads = 4; + + // number of elements per loop + auto delta_x = (stopX - startX); + auto delta_y = (stopY - startY); + + // number of iterations per 
loop + auto itersX = delta_x / incX; + auto itersY = delta_y / incY; + + for (int e = 0; e < iterations; e++) { + auto timeStart = std::chrono::system_clock::now(); + + // picking best fit here + auto splitLoop = samediff::ThreadsHelper::pickLoop2d(numThreads, itersX, itersY); + auto span = samediff::Span2::build(splitLoop, 0, numThreads, startX, stopX, incX, startY, stopY, incY); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Calculations time: [Median: %lld; Min: %lld; Max: %lld;]\n", values[values.size() / 2], values[0], values[values.size()-1]); +} + + +TEST_F(PlaygroundTests, test_s_6) { + auto x = NDArrayFactory::create('c', {1024 * 1024 * 64}); + auto buffer = x.bufferAsT(); + auto len = x.lengthOf(); + std::vector values; + auto iterations = 1000; + + for (int i = 0; i < iterations; i++) { + auto timeStart = std::chrono::system_clock::now(); + + // picking best fit here + for (int e = 0; e < len; e++) { + buffer[e] = (buffer[e] + 1.72f) * 3.17f - 0.0012f; + } + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + std::sort(values.begin(), values.end()); + + nd4j_printf("Calculations time: [Median: %lld; Min: %lld; Max: %lld;]\n", values[values.size() / 2], values[0], values[values.size()-1]); +} + + +TEST_F(PlaygroundTests, test_s_3) { + std::atomic s; + s = 0; + auto func = PRAGMA_THREADS_FOR { + s++; + }; + + for (int e = 0; e < 10000; e++) { + + samediff::Threads::parallel_for(func, 0, 8192, 1, 4); + } +} + */ + +/* +TEST_F(PlaygroundTests, test_relubp_1) { + auto x = NDArrayFactory::create('c', {128, 64, 224, 224}); + auto y = x.ulike(); + auto z = x.ulike(); + RandomGenerator rng(119, 120); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, 
&x, -1.0, 1.0); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), rng, &y, -1.0, 1.0); + + int iterations = 10; + + auto timeStart = std::chrono::system_clock::now(); + for (int e = 0; e < iterations; e++) + ops::helpers::reluDerivative(LaunchContext::defaultContext(), &x, &y, &z); + auto timeEnd = std::chrono::system_clock::now(); + + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + auto time = (Nd4jLong) outerTime / iterations; + auto bw = (1000000L * (float) (x.lengthOf() * x.sizeOfT()) / time) / 1024 / 1024 / 1024; + + nd4j_printf("Time: %lld; BW: %f GB/s\n", time, bw); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(PlaygroundTests, my) { + + int bS=8, iD=32,iH=32,iW=32, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; + int oD,oH,oW; + + sd::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0); + + printf("!!%i, %i, %i\n", oD,oH,oW); + + NDArray col('c', {bS, iC, kD, kH, kW, iD, iH, iW}, sd::DataType::DOUBLE); + NDArray vol('c', {bS, iC, oD, oH, oW}, sd::DataType::DOUBLE); + + col = 3.77; + vol = -10.33; + + auto variableSpace = new VariableSpace(); + auto block = new Context(1, variableSpace, false); // not-in-place + + auto timeStart = std::chrono::system_clock::now(); + sd::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW); + auto timeEnd = std::chrono::system_clock::now(); + auto time = std::chrono::duration_cast (timeEnd - timeStart).count(); + + printf("time: %i \n", time); + + delete block; + delete variableSpace; +} + +TEST_F(PlaygroundTests, my) { + + int bS=32, iD=32,iH=64,iW=64, iC=128, kD=2,kH=2,kW=2, sD=1,sH=1,sW=1, pD=0,pH=0,pW=0, dD=2,dH=2,dW=2; + int oD,oH,oW; + + // sd::ops::ConvolutionUtils::calcOutSizeDeconv3D(oD, oH, oW, kD, kH, kW, sD, sH, sW, pD, pH, pW, dD, dH, dW, iD, iH, iW, 0); + sd::ops::ConvolutionUtils::calcOutSizeDeconv2D(oH, oW, 
kH, kW, sH, sW, pH, pW,dH, dW, iH, iW, 0); + + printf("!!%i, %i, %i\n", oD,oH,oW); + + // NDArray col('c', {bS, iC, kD, kH, kW, iD, iH, iW}, sd::DataType::DOUBLE); + // NDArray vol('c', {bS, iC, oD, oH, oW}, sd::DataType::DOUBLE); + NDArray col('c', {bS, iC, kH, kW, iH, iW}, sd::DataType::DOUBLE); + NDArray im('c', {bS, iC, oH, oW}, sd::DataType::DOUBLE); + + col = 3.77; + // vol = -10.33; + im = -10.33; + + auto variableSpace = new VariableSpace(); + auto block = new Context(1, variableSpace, false); // not-in-place + + auto timeStart = std::chrono::system_clock::now(); + // sd::ops::ConvolutionUtils::col2vol(*block, col, vol, sD, sH, sW, pD, pH, pW, dD, dH, dW); + sd::ops::helpers::col2im(*col.getContext(), col, im, sH, sW, pH, pW, iH, iW, dH, dW); + auto timeEnd = std::chrono::system_clock::now(); + auto time = std::chrono::duration_cast (timeEnd - timeStart).count(); + + printf("time: %i \n", time); + + delete block; + delete variableSpace; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(PlaygroundTests, lstmLayerCellBp_1) { + + const int bS = 2; + const int nIn = 4; + const int nOut = 3; + // const int nIn = 8; + // const int nOut = 6; + + const float cellClip = 1.1; // clipping value + const Nd4jLong gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const float gateAlpha = 0; // alpha value for activation for gates, not required for sigmoid + const float gateBeta = 0; // beta value for activation for gates, not required for sigmoid + const Nd4jLong cellAct = 0; // tanh activation for cell state + const float cellAlpha = 0; // alpha value for cell state activation, not required for tanh + const float cellBeta = 0; // beta value for cell state activation, not required for tanh + const Nd4jLong outAct = 0; // tanh activation for output + const float outAlpha = 0; // alpha value for output activation, not required for tanh + const float outBeta = 0; // beta value for output activation, not 
required for tanh + + NDArray x ('c', {bS, nIn}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdc('c', {bS, nOut}, sd::DataType::DOUBLE); + + // NDArray x ('c', {nIn}, sd::DataType::DOUBLE); + // NDArray hI('c', {nOut}, sd::DataType::DOUBLE); + // NDArray cI('c', {nOut}, sd::DataType::DOUBLE); + // NDArray dLdh('c', {nOut}, sd::DataType::DOUBLE); + // NDArray dLdc('c', {nOut}, sd::DataType::DOUBLE); + + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b ('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + + x.linspace(-4,1); + hI.linspace(-2.5,0.5); + cI.linspace(-3,0.5); + Wx.linspace(0,0.1); + Wr.linspace(3,-0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + // x.assign(1.); + // hI.assign(2.); + // cI.assign(3.); + // Wx.assign(0.5); + // Wr.assign(0.5); + // Wp.assign(0.75); + // b.assign(0.7); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {gateAct, cellAct, outAct}; + + // std::vector bArgs = {false, false}; + // const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &hI, &cI}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &hI, &cI, &dLdh}, tArgs, iArgs, bArgs); + + std::vector bArgs = {true, true}; + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayerCell opFF; + sd::ops::lstmLayerCellBp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, true, true, true}); +} + +TEST_F(PlaygroundTests, my) { + + const int N = 40; + + NDArray x('c', {256,256,128,128}, sd::DataType::FLOAT32); + NDArray z1('c', {256,2,128}, sd::DataType::DOUBLE); + NDArray z 
= z1({0,0,0,1,0,0}); + z.printShapeInfo(); + + auto timeStart = std::chrono::system_clock::now(); + for (int i = 0; i < N; ++i) { + // x.reduceAlongDimension(sd::reduce::Mean, z, {1,3}); + x.applyBroadcast(sd::broadcast::Ops::Add, {1,3}, z, x); + } + auto timeEnd = std::chrono::system_clock::now(); + auto time = std::chrono::duration_cast ((timeEnd - timeStart) / N).count(); + + printf("old %ld\n", time); +} + + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_1) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 3; + + const int dataFormat = 0; // [sL,bS,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + 
x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_2) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 3; + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // return whole h {h_0, h_1, ... 
, h_sL-1}, [sL,bS,nOut] + const auto retLastH = false; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, std::vector(), {0., 1.}, GradCheck::LossFunc::MEAN); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_3) { + + const int sL = 4; + const int bS = 3; + const int nIn = 3; + const int nOut = 2; + + const int dataFormat = 2; // [bS, nIn, sL] + const int directionMode = 0; // forward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for 
output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {2,0,4}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); + + ASSERT_TRUE(isGradCorrect); +} + 
+/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_4) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 3; + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = false; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, sL, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, sL, nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder 
argsHolderFF({&x, &Wx, &Wr, &b, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_5) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 2; // [bS, nIn, sL] + const int directionMode = 1; // backward + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {4*nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {bS, nOut}, 
sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_6) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 2; // [bS, nIn, sL] + const int directionMode = 2; // bidirectional sum + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS, nIn, sL}, sd::DataType::DOUBLE); + NDArray Wx('c', {2, nIn, 
4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); + NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS, nOut, sL}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, 
{true, true, true, true, false, true, true, true}); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_7) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 1; // [bS,sL,nIn] + const int directionMode = 3; // bidirectional concat + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {bS,sL,nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); + NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {bS,sL,2*nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, 
directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); + + ASSERT_TRUE(isGradCorrect); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests13, lstmLayer_bp_8) { + + const int sL = 3; + const int bS = 2; + const int nIn = 2; + const int nOut = 2; + + const int dataFormat = 3; // [sL, bS, nIn] + const int directionMode = 4; // bidirectional extra output dim + const int gateAct = 2; // sigmoid activation for input (i), forget (f) and output (o) gates + const int cellAct = 0; // tanh activation for cell state + const int outAct = 0; // tanh activation for output + + const bool hasBiases = true; // biases array is provided + const bool hasSeqLen = true; // seqLen array is not provided + const auto hasInitH = true; // initial output is 
provided + const auto hasInitC = true; // initial cell state is provided + const auto hasPH = true; // peephole connections are absent + const auto retFullSeq = true; // dLdh per each time step + const auto retLastH = true; // output at last time step + const auto retLastC = true; // cells state at last time step + + const double cellClip = 0.5; // clipping + + NDArray x('c', {sL, bS, nIn}, sd::DataType::DOUBLE); + NDArray Wx('c', {2, nIn, 4*nOut}, sd::DataType::DOUBLE); + NDArray Wr('c', {2, nOut, 4*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {2, 4*nOut}, sd::DataType::DOUBLE); + NDArray seqLen('c', {bS}, {0,2}, sd::DataType::DOUBLE); + NDArray hI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray cI('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray Wp('c', {2, 3*nOut}, sd::DataType::DOUBLE); + NDArray dLdh('c', {sL, 2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdhL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + NDArray dLdcL('c', {2, bS, nOut}, sd::DataType::DOUBLE); + + x.linspace(-2,0.1); + hI.linspace(-1.5,0.1); + cI.linspace(0.7,-0.1); + Wx.linspace(1,-0.1); + Wr.linspace(-1,0.1); + Wp.linspace(0.2,0.2); + b.linspace(1,-0.15); + + std::vector tArgs = {cellClip}; + std::vector iArgs = {dataFormat, directionMode, gateAct, cellAct, outAct}; + std::vector bArgs = {hasBiases, hasSeqLen, hasInitH, hasInitC, hasPH, retFullSeq, retLastH, retLastC}; + + const OpArgsHolder argsHolderFF({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp}, tArgs, iArgs, bArgs); + const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL, &dLdcL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, 
&Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdh}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdhL}, tArgs, iArgs, bArgs); + // const OpArgsHolder argsHolderBP({&x, &Wx, &Wr, &b, &seqLen, &hI, &cI, &Wp, &dLdcL}, tArgs, iArgs, bArgs); + + sd::ops::lstmLayer opFF; + sd::ops::lstmLayer_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP, {true, true, true, true, false, true, true, true}); + + ASSERT_TRUE(isGradCorrect); +} + +////////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests15, gru_bp_1) { + + const int sL = 3; + const int bS = 2; + const int nIn = 5; + const int nOut = 4; + + + NDArray x('c', {sL, bS, nIn}, {0.5, 1. , 1.5, 2. , 2.5, 3. , 3.5, 4. , 4.5, 5. , 5.5, 6. , 6.5, 7. , 7.5, 8. , 8.5, 9. , 9.5, 10. , 10.5, 11. , 11.5, 12. , 12.5, 13. , 13.5, 14. , 14.5, 15.}, sd::DataType::DOUBLE); + NDArray hI('c', {bS, nOut}, {-3,-2,-1,0,1,2,3,4}, sd::DataType::DOUBLE); + NDArray Wx('c', {nIn, 3*nOut}, sd::DataType::DOUBLE); + NDArray Wh('c', {nOut, 3*nOut}, sd::DataType::DOUBLE); + NDArray b('c', {3*nOut}, sd::DataType::DOUBLE); + + NDArray dLdh('c', {sL, bS, nOut}, sd::DataType::DOUBLE); + + Wx.linspace(1,-0.1); + Wh.linspace(0.2,0.2); + b.linspace(1,-0.15); + + const OpArgsHolder argsHolderFF({&x, &hI, &Wx, &Wh, &b}, {}, {}); + const OpArgsHolder argsHolderBP({&x, &hI, &Wx, &Wh, &b, &dLdh}, {}, {}); + + sd::ops::gru opFF; + sd::ops::gru_bp opBP; + + const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP); +} + +*/ diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PrimitivesTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PrimitivesTests.cpp new file mode 100644 index 000000000..e8fa8d314 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/PrimitivesTests.cpp @@ -0,0 +1,94 @@ +/* 
****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver110@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class PrimitivesTests : public testing::Test { + public: + + PrimitivesTests() { + } +}; + +TEST_F(PrimitivesTests, test_mod_1) { + int ix = 7; + int iy = 3; + + + auto v = simdOps::Mod::op(ix, iy); + + ASSERT_EQ(7 % 3, v); +} + +TEST_F(PrimitivesTests, test_mod_2) { + float ix = 7.f; + float iy = 3.f; + + + auto e = sd::math::nd4j_fmod(ix, iy); + auto v = simdOps::Mod::op(ix, iy); + + ASSERT_NEAR(e, v, 1e-5f); +} + +TEST_F(PrimitivesTests, test_mod_3) { + float ix = 7.f; + float iy = 0.f; + + + auto e = sd::math::nd4j_fmod(ix, iy); + auto v = simdOps::Mod::op(ix, iy); + + // absence of SIGFPE will be a good enough +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ProtoBufTests.cpp 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ProtoBufTests.cpp new file mode 100644 index 000000000..c2c147a9d --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ProtoBufTests.cpp @@ -0,0 +1,112 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + + +#include "testlayers.h" +#include + +/* + +using namespace sd::graph; + +class ProtoBufTests : public testing::Test { + +}; + +TEST_F(ProtoBufTests, TestBinaryLoad1) { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/tensorflow_inception_graph.pb"); + + ASSERT_FALSE(graph == nullptr); +} + +TEST_F(ProtoBufTests, TestTextLoad1) { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_graph.pb.txt"); + + ASSERT_FALSE(graph == nullptr); +} + + +TEST_F(ProtoBufTests, TestTextLoad2) { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_add_2.pb.txt"); + + ASSERT_FALSE(graph == nullptr); + + ASSERT_EQ(2, graph->getVariableSpace()->externalEntries()); 
+ + auto var0 = graph->getVariableSpace()->getVariable(new std::string("zeros")); + auto var1 = graph->getVariableSpace()->getVariable(new std::string("ones")); + + + // first we're veryfying variable states + ASSERT_TRUE(var0 != nullptr); + ASSERT_TRUE(var1 != nullptr); + + ASSERT_TRUE(var0->getNDArray() != nullptr); + ASSERT_TRUE(var1->getNDArray() != nullptr); + + ASSERT_EQ(12, var0->getNDArray()->lengthOf()); + ASSERT_EQ(12, var1->getNDArray()->lengthOf()); + + ASSERT_NEAR(0.0f, var0->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(12.0f, var1->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(1.0f, var1->getNDArray()->reduceNumber>(), 1e-5); + + + // now we're veryfying op graph + ASSERT_EQ(1, graph->totalNodes()); + + GraphExecutioner::execute(graph); + + ASSERT_NEAR(12.0f, var0->getNDArray()->reduceNumber>(), 1e-5); + ASSERT_NEAR(1.0f, var0->getNDArray()->reduceNumber>(), 1e-5); +} + + +TEST_F(ProtoBufTests, TestTextLoad3) { + GOOGLE_PROTOBUF_VERIFY_VERSION; + + auto graph = GraphExecutioner::importFromTensorFlow("../../../tests/resources/max_multiply.pb.txt"); + + ASSERT_FALSE(graph == nullptr); + + ASSERT_EQ(2, graph->getVariableSpace()->externalEntries()); + + auto var0 = graph->getVariableSpace()->getVariable(new std::string("Placeholder")); + auto var1 = graph->getVariableSpace()->getVariable(new std::string("Placeholder_1")); + + ASSERT_TRUE(var0 != nullptr); + ASSERT_TRUE(var1 != nullptr); + + // we expect both variables to be set to null here + ASSERT_TRUE(var0->getNDArray() == nullptr); + ASSERT_TRUE(var1->getNDArray() == nullptr); + + // now we're veryfying op graph + ASSERT_EQ(1, graph->totalNodes()); +} +*/ \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/QuantizationTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/QuantizationTests.cpp new file mode 100644 index 000000000..a44a2053a --- /dev/null +++ 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/QuantizationTests.cpp @@ -0,0 +1,72 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@protonmail.com +// + + +#include "testlayers.h" +#include +#include + + +using namespace sd; + +class QuantizationTests : public testing::Test { + +}; + +TEST_F(QuantizationTests, Basic_Test_1) { +#ifndef __CUDABLAS__ + auto s = TypeCast::estimateQuantizedSize(10); + ASSERT_EQ(18, s); +#endif +} + +TEST_F(QuantizationTests, Basic_Test_2) { +#ifndef __CUDABLAS__ + auto s = TypeCast::estimateQuantizedSize(1); + ASSERT_EQ(9, s); +#endif +} + +TEST_F(QuantizationTests, Compression_Test_1) { + + #ifndef __CUDABLAS__ + + auto x = NDArrayFactory::create('c', {10}); + auto z = NDArrayFactory::create('c', {10}); + x.linspace(1.0f); + + auto q = new char[TypeCast::estimateQuantizedSize(x.lengthOf())]; + + TypeCast::convertToQuantized(nullptr, x.buffer(), x.lengthOf(), q); + TypeCast::convertFromQuantized(nullptr, q, x.lengthOf(), z.buffer()); + + ASSERT_TRUE(x.equalsTo(z, 0.1)); + + auto fq = reinterpret_cast(q); + + ASSERT_NEAR(1.0f, fq[0], 1e-5); + ASSERT_NEAR(10.0f, fq[1], 1e-5); + + delete[] q; + + #endif 
+} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/RNGTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/RNGTests.cpp new file mode 100644 index 000000000..8da3542f6 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/RNGTests.cpp @@ -0,0 +1,1299 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. 
+ * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + +using namespace sd; + +class RNGTests : public testing::Test { +private: + + +public: + long _seed = 119L; + + sd::graph::RandomGenerator _rngA; + sd::graph::RandomGenerator _rngB; + + NDArray* nexp0 = NDArrayFactory::create_('c', {10, 10}); + NDArray* nexp1 = NDArrayFactory::create_('c', {10, 10}); + NDArray* nexp2 = NDArrayFactory::create_('c', {10, 10}); + + RNGTests() { + + _rngA.setStates(_seed * 0xDEADBEEF * 13, _seed * 0xDEADBEEF * 7); + _rngB.setStates(_seed * 0xDEADBEEF * 13, _seed * 0xDEADBEEF * 7); + nexp0->assign(-1.0f); + nexp1->assign(-2.0f); + nexp2->assign(-3.0f); + } + + ~RNGTests() { + + delete nexp0; + delete nexp1; + delete nexp2; + } +}; + +TEST_F(RNGTests, TestSeeds_1) { + RandomGenerator generator(123L, 456L); + + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, generator.nodeState()); + + Nd4jPointer ptr = malloc(sizeof(RandomGenerator)); + memcpy(ptr, &generator, sizeof(RandomGenerator)); + + auto cast = reinterpret_cast(ptr); + ASSERT_EQ(123, cast->rootState()); + ASSERT_EQ(456, cast->nodeState()); + + free(ptr); +} + +TEST_F(RNGTests, TestSeeds_2) { + RandomGenerator generator(12, 13); + + generator.setStates(123L, 456L); + + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, generator.nodeState()); +} + +TEST_F(RNGTests, TestGenerator_SGA_1) { + RandomGenerator generator(12, 13); + auto array= NDArrayFactory::create('c',{10000000}); + generator.setStates(123L, 456L); + for (auto idx = 0; idx < array.lengthOf(); idx++) { + float x = generator.relativeT(idx, -sd::DataTypeUtils::template max() / 10, + sd::DataTypeUtils::template max() / 10); + array.r(idx) = x; + } + auto minimum = array.reduceNumber(reduce::AMin); + ASSERT_EQ(123, generator.rootState()); + ASSERT_EQ(456, 
generator.nodeState()); +} + + +TEST_F(RNGTests, Test_Dropout_1) { + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + x0.linspace(1); + x1.linspace(1); + + float prob[] = {0.5f}; + + + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5); + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5); + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); +} + +TEST_F(RNGTests, Test_DropoutInverted_1) { + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + x0.linspace(1); + x1.linspace(1); + + float prob[] = {0.5f}; + + + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5); + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5); + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); +} + + +TEST_F(RNGTests, Test_Launcher_1) { + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5f); + RandomLauncher::applyDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); +} + + +TEST_F(RNGTests, Test_Launcher_2) { + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5f); + RandomLauncher::applyInvertedDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + 
ASSERT_FALSE(x0.equalsTo(nexp2)); +} + + +TEST_F(RNGTests, Test_Launcher_3) { + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::applyAlphaDropOut(LaunchContext::defaultContext(), _rngA, &x0, 0.5f, 0.2f, 0.1f, 0.3f); + RandomLauncher::applyAlphaDropOut(LaunchContext::defaultContext(), _rngB, &x1, 0.5f, 0.2f, 0.1f, 0.3f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); +} + +TEST_F(RNGTests, Test_Uniform_1) { + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); + + for (int e = 0; e < x0.lengthOf(); e++) { + float v = x0.e(e); + ASSERT_TRUE(v >= 1.0f && v <= 2.0f); + } +} + +TEST_F(RNGTests, Test_Uniform_10) { + auto x = NDArrayFactory::create('c', {10000, 10000}); + auto z = NDArrayFactory::create(0.0f); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x, 0.0f, 1.0f); + + sd::ops::reduce_max op; + auto status = op.execute({&x}, {&z}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_LT(z.t(0), 1.0f); +} + +TEST_F(RNGTests, Test_Uniform_10_double) { + auto x = NDArrayFactory::create('c', {10000, 10000}); + auto z = NDArrayFactory::create(0.0f); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x, 0.0f, 1.0f); + + sd::ops::reduce_max op; + auto status = op.execute({&x}, {&z}); + ASSERT_EQ(Status::OK(), status); + + ASSERT_LT(z.t(0), 1.0); +} + +TEST_F(RNGTests, Test_Uniform_11) { + uint32_t max = 0; + for (int e = 0; e < 100000000; e++) { + auto v = _rngA.xoroshiro32(e) >> 8; + if (v > max) + 
max = v; + } +} + +TEST_F(RNGTests, Test_Uniform_12) { + float max = -std::numeric_limits::infinity(); + float min = std::numeric_limits::infinity(); + for (int e = 0; e < 100000000; e++) { + auto v = _rngA.relativeT(e); + if (v > max) + max = v; + + if (v < min) + min = v; + } + + ASSERT_LT(max, 1.0f); + ASSERT_GE(min, 0.0); +} + +TEST_F(RNGTests, Test_Uniform_13) { + double max = -std::numeric_limits::infinity(); + double min = std::numeric_limits::infinity(); + for (int e = 0; e < 100000000; e++) { + auto v = _rngA.relativeT(e); + if (v > max) + max = v; + + if (v < min) + min = v; + } + + ASSERT_LT(max, 1.0); + ASSERT_GE(min, 0.0); +} + +TEST_F(RNGTests, Test_Uniform_3) { + auto x0 = NDArrayFactory::create('c', {1000000}); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); + + for (int e = 0; e < x0.lengthOf(); e++) { + auto v = x0.t(e); + ASSERT_TRUE(v >= 1.0 && v <= 2.0); + } +} + +TEST_F(RNGTests, Test_Bernoulli_1) { + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngA, &x0, 1.0f); + RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, 1.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); +} + +TEST_F(RNGTests, Test_Gaussian_1) { + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); +} + +TEST_F(RNGTests, Test_Gaussian_21) { + auto x0 = NDArrayFactory::create('c', {1000, 1000}); + auto x1 = 
NDArrayFactory::create('c', {1000, 1000}); + + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}); + ASSERT_TRUE(result.status() == Status::OK()); + auto mean = result.at(0); + auto variance = result.at(1); + + ASSERT_NEAR(sd::math::nd4j_abs(mean->e(0)), 0.f, 0.2f); + ASSERT_NEAR(variance->e(0), 1.0f, 0.2f); + + +} + +#ifdef DEBUG_BUILD +TEST_F(RNGTests, Test_Gaussian_22) { + auto x0 = NDArrayFactory::create('c', {1000, 800}); + auto x1 = NDArrayFactory::create('c', {1000, 800}); + + RandomLauncher::fillGaussian(sd::LaunchContext::defaultContext(), _rngA, &x0, 0.0f, 1.0f); + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}); + + ASSERT_TRUE(result.status() == Status::OK()); + auto mean0 = result.at(0); + auto variance0 = result.at(1); + + ASSERT_NEAR(sd::math::nd4j_abs(mean0->e(0)), 0.f, 1.0e-3f); + ASSERT_NEAR(variance0->e(0), 1.0f, 1.e-3f); + +} + +TEST_F(RNGTests, Test_Gaussian_3) { + auto x0 = NDArrayFactory::create('c', {800000}); + + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngA, &x0, 0.0, 1.0); + + auto mean = x0.meanNumber(); + auto stdev = x0.varianceNumber(sd::variance::SummaryStatsStandardDeviation, false); + auto meanExp = NDArrayFactory::create(0.); + auto devExp = NDArrayFactory::create(1.); + ASSERT_TRUE(meanExp.equalsTo(mean, 1.e-3)); + ASSERT_TRUE(devExp.equalsTo(stdev, 1.e-3)); +} + +TEST_F(RNGTests, Test_LogNormal_1) { + auto x0 = 
NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); + RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); +} + +TEST_F(RNGTests, Test_Truncated_1) { + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); +} + +TEST_F(RNGTests, Test_Truncated_2) { + auto x0 = NDArrayFactory::create('c', {1000, 1000}); + auto x1 = NDArrayFactory::create('c', {1000, 1000}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + + auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + ASSERT_NEAR(mean.e(0), 1.f, 0.5); + ASSERT_NEAR(deviation.e(0), 2.f, 0.5); +} + +TEST_F(RNGTests, Test_Truncated_21) { + auto x0 = NDArrayFactory::create('c', {100, 100}); + auto x1 = NDArrayFactory::create('c', {100, 100}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); + 
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + auto mean0 = x0.reduceNumber(reduce::Mean); + + auto deviation0 = x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + + auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + + ASSERT_NEAR(mean.e(0), 1.f, 0.002); + ASSERT_NEAR(deviation.e(0), 2.f, 0.5); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); + + sd::ops::reduce_min minOp; + sd::ops::reduce_max maxOp; + + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); +} + +TEST_F(RNGTests, Test_Truncated_22) { + auto x0 = NDArrayFactory::create('c', {100, 100}); + auto x1 = NDArrayFactory::create('c', {100, 100}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 2.0f, 4.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 2.0f, 4.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + auto mean0 = x0.reduceNumber(reduce::Mean); + + auto deviation0 = x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + + auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + ASSERT_NEAR(mean.e(0), 2.f, 0.01); + ASSERT_NEAR(deviation.e(0), 4.f, 0.52); + sd::ops::moments op; + auto result = op.evaluate({&x0}, {}, {}, {}, {}, false); + + sd::ops::reduce_min minOp; + sd::ops::reduce_max maxOp; + + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); +} + +TEST_F(RNGTests, Test_Truncated_23) { + auto x0 = NDArrayFactory::create('c', {1000, 1000}); + auto x1 = NDArrayFactory::create('c', {1000, 1000}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, 
&x0, 0.0f, 1.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 0.0f, 1.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + auto mean0 = x0.reduceNumber(reduce::Mean); + + auto deviation0 = x0.varianceNumber(variance::SummaryStatsStandardDeviation, false); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + + auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + ASSERT_NEAR(mean.e(0), 0.f, 0.01); + ASSERT_NEAR(deviation.e(0), 1.f, 0.5); + sd::ops::moments op; + auto result = op.evaluate({&x0}); + sd::ops::reduce_min minOp; + sd::ops::reduce_max maxOp; + + auto minRes = minOp.evaluate({&x1}, {}, {}, {}); + auto maxRes = maxOp.evaluate({&x0}, {}, {}, {}); + +} + +TEST_F(RNGTests, Test_Truncated_3) { + auto x0 = NDArrayFactory::create('c', {2000, 2000}); + auto x1 = NDArrayFactory::create('c', {2000, 2000}); + + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngA, &x0, 1.0f, 2.0f); + RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + // Check up distribution + auto mean = x1.reduceNumber(reduce::Mean); + + auto deviation = x1.varianceNumber(variance::SummaryStatsStandardDeviation, false); + ASSERT_NEAR(mean.e(0), 1.f, 0.001); + ASSERT_NEAR(deviation.e(0), 2.f, 0.3); +} +#endif + +TEST_F(RNGTests, Test_Binomial_1) { + auto x0 = NDArrayFactory::create('c', {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngA, &x0, 3, 2.0f); + RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, 2.0f); + + ASSERT_TRUE(x0.equalsTo(&x1)); + + ASSERT_FALSE(x0.equalsTo(nexp0)); + ASSERT_FALSE(x0.equalsTo(nexp1)); + ASSERT_FALSE(x0.equalsTo(nexp2)); +} + + +TEST_F(RNGTests, Test_Uniform_2) { + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 
10}); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + auto op = new sd::ops::LegacyRandomOp(0); + auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); + + delete op; + +} + +TEST_F(RNGTests, Test_Uniform_SGA_3) { + //auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, -sd::DataTypeUtils::template max(), sd::DataTypeUtils::template max()); + auto minimumU = x1.reduceNumber(reduce::AMin); +} + +TEST_F(RNGTests, Test_Gaussian_2) { + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillGaussian(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + auto op = new sd::ops::LegacyRandomOp(random::GaussianDistribution); + auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); + + delete op; + +} + +TEST_F(RNGTests, Test_LogNorm_2) { + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillLogNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + auto op = new sd::ops::LegacyRandomOp(random::LogNormalDistribution); + auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); + + delete op; + +} + +TEST_F(RNGTests, Test_TruncatedNorm_2) { + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + 
RandomLauncher::fillTruncatedNormal(LaunchContext::defaultContext(), _rngB, &x1, 1.0f, 2.0f); + + auto op = new sd::ops::LegacyRandomOp(random::TruncatedNormalDistribution); + auto result = op->execute(_rngA, {&input}, {1.0f, 2.0f}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); + delete op; + +} + + +TEST_F(RNGTests, Test_Binomial_2) { + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillBinomial(LaunchContext::defaultContext(), _rngB, &x1, 3, 0.5f); + + auto op = new sd::ops::LegacyRandomOp(random::BinomialDistributionEx); + auto result = op->execute(_rngA, {&input}, {0.5f}, {3}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); + + delete op; + +} + + +TEST_F(RNGTests, Test_Bernoulli_2) { + auto input = NDArrayFactory::create('c', {1, 2}, {10, 10}); + auto x1 = NDArrayFactory::create('c', {10, 10}); + + RandomLauncher::fillBernoulli(LaunchContext::defaultContext(), _rngB, &x1, 0.5f); + + auto op = new sd::ops::LegacyRandomOp(random::BernoulliDistribution); + auto result = op->execute(_rngA, {&input}, {0.5f}, {}); + + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(x1.isSameShape(z)); + ASSERT_TRUE(x1.equalsTo(z)); + + delete op; + +} + +TEST_F(RNGTests, Test_GaussianDistribution_1) { + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + + + sd::ops::random_normal op; + auto result = op.evaluate({&x}, {0.0, 1.0f}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + + + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + + +} + +TEST_F(RNGTests, 
Test_BernoulliDistribution_1) { + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + + + sd::ops::random_bernoulli op; + auto result = op.evaluate({&x}, {0.5f}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_FALSE(exp0.equalsTo(z)); + + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + + +} + + +TEST_F(RNGTests, Test_ExponentialDistribution_1) { + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + + + sd::ops::random_exponential op; + auto result = op.evaluate({&x}, {0.25f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + + auto mean = z->reduceNumber(reduce::Mean); + auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + +} + +TEST_F(RNGTests, Test_ExponentialDistribution_1_SGA) { + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + + + sd::ops::random_exponential op; + auto result = op.evaluate({&x}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + + auto mean = z->reduceNumber(reduce::Mean); + auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + + +} + +TEST_F(RNGTests, Test_ExponentialDistribution_2_SGA) { + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + RandomGenerator oc(2716049175077475646L, -6182841917129177862L); + + sd::ops::random_exponential op; + 
RandomLauncher::fillExponential(x.getContext(), oc, &exp0, 2.f); + auto result = op.evaluate({&x}, {1.f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + + auto mean = z->reduceNumber(reduce::Mean); + auto variance = z->varianceNumber(variance::SummaryStatsVariance, false); + + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + mean = exp0.reduceNumber(reduce::Mean); + variance = exp0.varianceNumber(variance::SummaryStatsVariance, false); + +} + +TEST_F(RNGTests, Test_ExponentialDistribution_2) { + auto x = NDArrayFactory::create('c', {2}, {10, 10}); + auto y = NDArrayFactory::create('c', {10, 10}); + auto exp0 = NDArrayFactory::create('c', {10, 10}); + + y.assign(1.0); + + + sd::ops::random_exponential op; + auto result = op.evaluate({&x, &y}, {0.25f}, {0}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + + + ASSERT_FALSE(nexp0->equalsTo(z)); + ASSERT_FALSE(nexp1->equalsTo(z)); + ASSERT_FALSE(nexp2->equalsTo(z)); + + +} + +TEST_F(RNGTests, Test_PoissonDistribution_1) { + auto x = NDArrayFactory::create('c', {1}, {10}); + auto la = NDArrayFactory::create('c', {2, 3}); + auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); + + la.linspace(1.0); + + + sd::ops::random_poisson op; + auto result = op.evaluate({&x, &la}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); +} + +TEST_F(RNGTests, Test_GammaDistribution_1) { + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create('c', {2, 3}); + auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); + + al.linspace(1.0); + + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto 
z = result.at(0); + + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); +} + +TEST_F(RNGTests, Test_GammaDistribution_2) { + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create('c', {2, 3}); + auto be = NDArrayFactory::create('c', {2, 3}); + auto exp0 = NDArrayFactory::create('c', {10, 2, 3}); + + al.linspace(1.0); + be.assign(1.0); + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al, &be}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); +} + +TEST_F(RNGTests, Test_GammaDistribution_3) { + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create('c', {3, 1}); + auto be = NDArrayFactory::create('c', {1, 2}); + auto exp0 = NDArrayFactory::create('c', {10, 3, 2}); + + al.linspace(1.0); + be.assign(2.0); + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al, &be}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + + +} + +TEST_F(RNGTests, Test_GammaDistribution_4) { + auto x = NDArrayFactory::create('c', {2}, {1000, 1000}); + auto al = NDArrayFactory::create(2.f); + auto be = NDArrayFactory::create(2.f); + auto exp0 = NDArrayFactory::create('c', {1000, 1000}); + +// al.linspace(1.0); +// be.assign(2.0); + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al, &be}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Gamma distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + sd::ops::reduce_mean testOps1; + sd::ops::reduce_variance testOps2; + auto testRes1 = testOps1.evaluate({z}); + auto testRes2 = testOps2.evaluate({z}); +// testRes1[0]->printBuffer("Mean (expected 1.0)"); +// testRes2[0]->printBuffer("Variance (expected 0.5)"); + ASSERT_NEAR(testRes1[0]->t(0), 
1.0f, 0.01); + ASSERT_NEAR(testRes2[0]->t(0), 0.5f, 0.02); +} + +TEST_F(RNGTests, Test_GammaDistribution_5) { + auto x = NDArrayFactory::create('c', {2}, {100, 100}); + auto al = NDArrayFactory::create(0.2f); + auto be = NDArrayFactory::create(2.f); + auto exp0 = NDArrayFactory::create('c', {100, 100}); + +// al.linspace(1.0); +// be.assign(2.0); + + sd::ops::random_gamma op; + auto result = op.evaluate({&x, &al, &be}, {}, {}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); +// z->printIndexedBuffer("Gamma distribution"); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); +// z->printIndexedBuffer("Gamma distributed"); + sd::ops::reduce_mean testOps1; + sd::ops::reduce_variance testOps2; + auto testRes1 = testOps1.evaluate({z}); + auto testRes2 = testOps2.evaluate({z}); +// testRes1[0]->printBuffer("Mean (expected 0.1)"); +// testRes2[0]->printBuffer("Variance (expected 0.05)"); + ASSERT_NEAR(testRes1[0]->t(0), 0.1f, 0.02); + ASSERT_NEAR(testRes2[0]->t(0), 0.05f, 0.02); +} + +TEST_F(RNGTests, Test_UniformDistribution_04) { + auto x = NDArrayFactory::create('c', {1}, {10}); + auto al = NDArrayFactory::create(1); + auto be = NDArrayFactory::create(20); + auto exp0 = NDArrayFactory::create('c', {10}); + + + sd::ops::randomuniform op; + auto result = op.evaluate({&x, &al, &be}, {}, {DataType::INT32}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + ASSERT_FALSE(exp0.equalsTo(z)); + +} + +TEST_F(RNGTests, Test_UniformDistribution_05) { + auto x = NDArrayFactory::create('c', {2}, {10000, 10000}); + auto al = NDArrayFactory::create(0.f); + auto be = NDArrayFactory::create(1.f); + auto exp0 = NDArrayFactory::create('c', {10000, 10000}); + + + sd::ops::randomuniform op; + auto result = op.evaluate({&x, &al, &be}, {}, {},{}, {DataType::FLOAT32}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + ASSERT_TRUE(exp0.isSameShape(z)); + 
ASSERT_FALSE(exp0.equalsTo(z)); + + sd::ops::reduce_max checkOp; + auto checkResult = checkOp.evaluate({z}); +} + +namespace sd { + namespace tests { + static void fillList(Nd4jLong seed, int numberOfArrays, std::vector &shape, std::vector &list, sd::graph::RandomGenerator *rng) { + rng->setSeed((int) seed); + + for (int i = 0; i < numberOfArrays; i++) { + auto arrayI = NDArrayFactory::create(shape); + auto arrayR = NDArrayFactory::create_('c', shape); + auto min = NDArrayFactory::create(0.0); + auto max = NDArrayFactory::create(1.0); + sd::ops::randomuniform op; + op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, {}, false); + + list.emplace_back(arrayR); + } + }; + } +} + +TEST_F(RNGTests, Test_Reproducibility_1) { + Nd4jLong seed = 123; + + std::vector shape = {32, 3, 28, 28}; + sd::graph::RandomGenerator rng; + + std::vector expList; + sd::tests::fillList(seed, 10, shape, expList, &rng); + + for (int e = 0; e < 2; e++) { + std::vector trialList; + sd::tests::fillList(seed, 10, shape, trialList, &rng); + + for (int a = 0; a < expList.size(); a++) { + auto arrayE = expList[a]; + auto arrayT = trialList[a]; + + bool t = arrayE->equalsTo(arrayT); + if (!t) { + ASSERT_TRUE(false); + } + + delete arrayT; + } + } + + for (auto v: expList) + delete v; +} + +#ifndef DEBUG_BUILD +TEST_F(RNGTests, Test_Reproducibility_2) { + Nd4jLong seed = 123; + + std::vector shape = {32, 3, 64, 64}; + sd::graph::RandomGenerator rng; + + std::vector expList; + sd::tests::fillList(seed, 10, shape, expList, &rng); + + for (int e = 0; e < 2; e++) { + std::vector trialList; + sd::tests::fillList(seed, 10, shape, trialList, &rng); + + for (int a = 0; a < expList.size(); a++) { + auto arrayE = expList[a]; + auto arrayT = trialList[a]; + + bool t = arrayE->equalsTo(arrayT); + if (!t) { + + for (Nd4jLong f = 0; f < arrayE->lengthOf(); f++) { + double x = arrayE->e(f); + double y = arrayT->e(f); + + if (sd::math::nd4j_re(x, y) > 0.1) { + throw 
std::runtime_error("boom"); + } + } + ASSERT_TRUE(false); + } + + delete arrayT; + } + } + + for (auto v: expList) + delete v; +} + +TEST_F(RNGTests, Test_Uniform_4) { + auto x1 = NDArrayFactory::create('c', {1000000}); + + RandomLauncher::fillUniform(LaunchContext::defaultContext(), _rngB, &x1, 1.0, 2.0); + + /* Check up distribution */ + auto mean = x1.reduceNumber(reduce::Mean); + // mean.printIndexedBuffer("Mean should be 1.5"); + auto sumA = x1 - mean; //.reduceNumber(reduce::Sum); + + auto deviation = x1.varianceNumber(variance::SummaryStatsVariance, false); + + ASSERT_NEAR(mean.e(0), 1.5, 1e-3); + ASSERT_NEAR(1/12., deviation.e(0), 1e-3); +} +#endif + +TEST_F(RNGTests, test_choice_1) { + const auto x = NDArrayFactory::linspace(0, 10, 11); + const auto prob = NDArrayFactory::valueOf({11}, 1.0/11, 'c'); + auto z = NDArrayFactory::create('c', {1000}); + + RandomGenerator rng(119, 256); + NativeOpExecutioner::execRandom(sd::LaunchContext ::defaultContext(), random::Choice, &rng, x->buffer(), x->shapeInfo(), x->specialBuffer(), x->specialShapeInfo(), prob->buffer(), prob->shapeInfo(), prob->specialBuffer(), prob->specialShapeInfo(), z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo(), nullptr); + + delete x; + delete prob; +} + +TEST_F(RNGTests, test_uniform_119) { + auto x = NDArrayFactory::create('c', {2}, {1, 5}); + auto z = NDArrayFactory::create('c', {1, 5}); + + + sd::ops::randomuniform op; + auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {}); + ASSERT_EQ(Status::OK(), status); +} + +TEST_F(RNGTests, test_multinomial_1) { + + NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, sd::DataType::FLOAT32); + NDArray expected('f', { 3, 3 }, { 0., 1, 2, 2, 0, 0, 1, 2, 1 }, sd::DataType::INT64); + NDArray output('f', { 3, 3 }, sd::DataType::INT64); + NDArray samples('f', { 1 }, std::vector({3}), sd::DataType::INT32); + + sd::ops::random_multinomial op; + RandomGenerator rng(1234, 1234); + ASSERT_EQ(Status::OK(), 
op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, {}, false) ); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, sd::DataType::FLOAT32); + NDArray expectedZ('c', { 3, 3 }, { 0., 0, 0, 0, 0, 0, 0, 0, 0 }, sd::DataType::INT64); + + auto result = op.evaluate({ &probsZ, &samples }, { }, { 1, INT64 }); + auto outputZ = result.at(0); + + ASSERT_EQ(Status::OK(), result.status()); + ASSERT_TRUE(expectedZ.isSameShape(outputZ)); + ASSERT_TRUE(expectedZ.equalsTo(outputZ)); +} + +TEST_F(RNGTests, test_multinomial_2) { + + NDArray samples('c', { 1 }, std::vector{ 20 }, sd::DataType::INT32); + NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, sd::DataType::FLOAT32); + NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, sd::DataType::INT64); + NDArray output('c', { 3, 20 }, sd::DataType::INT64); + + sd::ops::random_multinomial op; + RandomGenerator rng(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + NDArray probs2('c', { 5, 3 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, sd::DataType::FLOAT32); + NDArray expected2('c', { 20, 3 }, { 0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2 }, sd::DataType::INT64); + NDArray output2('c', { 20, 3 }, sd::DataType::INT64); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, {}, false)); + 
ASSERT_TRUE(expected2.isSameShape(output2)); + ASSERT_TRUE(expected2.equalsTo(output2)); +} + +TEST_F(RNGTests, test_multinomial_3) { + + NDArray probs('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, sd::DataType::FLOAT32); + NDArray expected('c', { 4, 5 }, sd::DataType::INT64); + NDArray output('c', { 4, 5 }, sd::DataType::INT64); + NDArray samples('c', { 1 }, std::vector{ 5 }, sd::DataType::INT32); + RandomGenerator rng(1234, 1234); + + sd::ops::random_multinomial op; + + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 0, INT64 }, {}, {}, false)); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +TEST_F(RNGTests, test_multinomial_4) { + + NDArray probs('c', { 3, 4 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, sd::DataType::FLOAT32); + NDArray expected('c', { 5, 4 }, sd::DataType::INT64); + NDArray output('c', { 5, 4 }, sd::DataType::INT64); + NDArray samples('c', { 1 }, std::vector{ 5 }, sd::DataType::INT32); + + RandomGenerator rng(1234, 1234); + sd::ops::random_multinomial op; + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &expected }, {}, { 1, INT64 }, {}, {}, false)); + + rng.setStates(1234, 1234); + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1, INT64 }, {}, {}, false)); + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); +} + +TEST_F(RNGTests, test_multinomial_5) { + // multinomial as binomial if 2 classes used + int batchValue = 1; + int ClassValue = 2; + int Samples = 100000; + + NDArray samples('c', { 1 }, std::vector{ 1.*Samples }, sd::DataType::INT32); + + NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, sd::DataType::FLOAT32); + + sd::ops::random_multinomial op; + + NDArray 
output('c', { Samples, batchValue }, sd::DataType::INT64); + RandomGenerator rng(1234, 1234); + + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, {}, false)); + auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto mean = output.meanNumber(); + + // theoretical values for binomial + ASSERT_NEAR(0.5, deviation.e(0), 4e-3); // 1000000 3e-3); + ASSERT_NEAR(0.5, mean.e(0), 4e-3); // 1000000 3e-3); + + for (int i = 0; i < output.lengthOf(); i++) { + auto value = output.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + } + + auto resultR = op.evaluate({ &probs, &samples }, { }, { 1 }); + auto outputR = resultR.at(0); + ASSERT_EQ(Status::OK(), resultR.status()); + + deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); + mean = outputR->meanNumber(); + + ASSERT_NEAR(0.5, deviation.e(0), 45e-3); // 1000000 35e-3); + ASSERT_NEAR(0.5, mean.e(0), 45e-3); // 1000000 35e-3); + + for (int i = 0; i < outputR->lengthOf(); i++) { + auto value = outputR->e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + } + +} + + +TEST_F(RNGTests, test_multinomial_6) { + + int batchValue = 1; + int ClassValue = 5; + int Samples = 100000; + + NDArray samples('c', { 1 }, std::vector{ 1. * Samples }, sd::DataType::INT32); + + sd::ops::random_multinomial op; + NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, sd::DataType::DOUBLE); + + // without seed + NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. 
}, sd::DataType::FLOAT32); + + auto resultR = op.evaluate({ &probsR, &samples }, { }, { 0 }); + auto outputR = resultR.at(0); + ASSERT_EQ(Status::OK(), resultR.status()); + + NDArray countsR('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE); + + for (int i = 0; i < outputR->lengthOf(); i++) { + auto value = outputR->e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + double* z = countsR.bufferAsT(); + z[value] += 1; + } + + for (int i = 0; i < countsR.lengthOf(); i++) { + auto c = countsR.e(i); + auto p = probExpect.e(i); + + ASSERT_NEAR((c / Samples), p, 45e-3); // 1000000 35e-3); + } + + auto deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false); + auto mean = outputR->meanNumber(); + + ASSERT_NEAR(1.2175, deviation.e(0), 45e-3); // 1000000 35e-3); + ASSERT_NEAR(2.906, mean.e(0), 45e-3); // 1000000 35e-3); + + + RandomGenerator rng(1234, 1234); + NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, sd::DataType::FLOAT32); + NDArray output('c', { batchValue, Samples }, sd::DataType::INT64); + + ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, {}, false)); + + NDArray counts('c', { ClassValue }, { 0., 0, 0, 0, 0 }, sd::DataType::DOUBLE); + + for (int i = 0; i < output.lengthOf(); i++) { + auto value = output.e(i); + ASSERT_TRUE(value >= 0 && value < ClassValue); + double* z = counts.bufferAsT(); + z[value] += 1; + } + + for (int i = 0; i < counts.lengthOf(); i++) { + auto c = counts.e(i); + auto p = probExpect.e(i); + + ASSERT_NEAR((c / Samples), p, 4e-3); // 1000000 3e-3); + } + + deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false); + mean = output.meanNumber(); + ASSERT_NEAR(1.2175, deviation.e(0), 5e-3); // 1000000 3e-3); + ASSERT_NEAR(2.906, mean.e(0), 5e-3); // 1000000 3e-3); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ResultSetTests.cpp 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ResultSetTests.cpp new file mode 100644 index 000000000..3e9753dc0 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ResultSetTests.cpp @@ -0,0 +1,51 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver on 4/18/2019. 
+// + +#include "testlayers.h" +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class ResultSetTests : public testing::Test { +public: + +}; + +TEST_F(ResultSetTests, basic_test_1) { + auto x = NDArrayFactory::create('c', {3, 5}); + + auto tensors = x.allTensorsAlongDimension({1}); + ASSERT_EQ(3, tensors.size()); + + ResultSet set = tensors; + ASSERT_EQ(3, tensors.size()); + ASSERT_EQ(3, set.size()); + + for (int e = 0; e < set.size(); e++) + ASSERT_EQ(5, set.at(e)->lengthOf()); + + for (int e = 0; e < tensors.size(); e++) + ASSERT_EQ(5, tensors.at(e)->lengthOf()); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SanityTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SanityTests.cpp new file mode 100644 index 000000000..fffb49854 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SanityTests.cpp @@ -0,0 +1,64 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 13/11/17. 
+// + +#include "testlayers.h" +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class SanityTests : public testing::Test { +public: + +}; + + +TEST_F(SanityTests, VariableSpace_1) { + VariableSpace variableSpace; + variableSpace.putVariable(1, new Variable()); + variableSpace.putVariable(1, 1, new Variable()); + + std::pair pair(1, 2); + variableSpace.putVariable(pair, new Variable()); +} + +TEST_F(SanityTests, VariableSpace_2) { + VariableSpace variableSpace; + variableSpace.putVariable(1, new Variable(NDArrayFactory::create_('c', {3, 3}))); + variableSpace.putVariable(1, 1, new Variable(NDArrayFactory::create_('c', {3, 3}))); + + std::pair pair(1, 2); + variableSpace.putVariable(pair, new Variable(NDArrayFactory::create_('c', {3, 3}))); +} + + +TEST_F(SanityTests, Graph_1) { + Graph graph; + + graph.getVariableSpace()->putVariable(1, new Variable(NDArrayFactory::create_('c', {3, 3}))); + graph.getVariableSpace()->putVariable(1, 1, new Variable(NDArrayFactory::create_('c', {3, 3}))); + + std::pair pair(1, 2); + graph.getVariableSpace()->putVariable(pair, new Variable(NDArrayFactory::create_('c', {3, 3}))); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ScalarTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ScalarTests.cpp new file mode 100644 index 000000000..e75d2acba --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ScalarTests.cpp @@ -0,0 +1,238 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class ScalarTests : public testing::Test { +public: + +}; + +TEST_F(ScalarTests, Test_Create_1) { + auto x = NDArrayFactory::create(2.0f); + + ASSERT_EQ(0, x.rankOf()); + ASSERT_EQ(1, x.lengthOf()); + ASSERT_TRUE(x.isScalar()); + ASSERT_FALSE(x.isVector()); + ASSERT_FALSE(x.isRowVector()); + ASSERT_FALSE(x.isColumnVector()); + ASSERT_FALSE(x.isMatrix()); +} + +TEST_F(ScalarTests, Test_Add_1) { + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create(5.0f); + + x += 3.0f; + + ASSERT_NEAR(5.0f, x.e(0), 1e-5f); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(ScalarTests, Test_Add_2) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create(5.0f); + + x += y; + + ASSERT_NEAR(5.0f, x.e(0), 1e-5f); + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(ScalarTests, Test_Add_3) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto y = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {4, 5, 6}); + + x += y; + + ASSERT_TRUE(exp.isSameShape(&x)); + + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(ScalarTests, Test_EQ_1) { + auto x = NDArrayFactory::create(2.0f); + auto y = NDArrayFactory::create(3.0f); + + ASSERT_TRUE(y.isSameShape(&x)); + ASSERT_FALSE(y.equalsTo(&x)); +} + +TEST_F(ScalarTests, 
Test_Concat_1) { + auto t = NDArrayFactory::create(1.0f); + auto u = NDArrayFactory::create(2.0f); + auto v = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); + + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(ScalarTests, Test_Concat_2) { + auto t = NDArrayFactory::create(1.0f); + auto u = NDArrayFactory::create('c', {3}, {2, 3, 4}); + auto v = NDArrayFactory::create(5.0f); + auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + // z->printIndexedBuffer(); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(ScalarTests, Test_Concat_3) { + auto t = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto u = NDArrayFactory::create(4.0f); + auto v = NDArrayFactory::create(5.0f); + auto exp = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + //z->printShapeInfo("z"); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ScalarTests, Test_ExpandDims_1) { + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create('c', {1}, {2.0f}); + + sd::ops::expand_dims op; + auto result = op.evaluate({&x}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ScalarTests, Test_Squeeze_1) { + auto x = NDArrayFactory::create(2.0f); + auto exp = NDArrayFactory::create(2.0f); + + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); + 
ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ScalarTests, Test_Permute_1) { + auto x = NDArrayFactory::create(3.0f); + auto exp = NDArrayFactory::create(3.0f); + + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(ScalarTests, Test_Concat_Scalar_1) { + auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); + auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); + auto exp = NDArrayFactory::create('c', {4, 1}, {1, 2, 3, 4}); + + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v, &w}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + + +TEST_F(ScalarTests, Test_Concat_Scalar_2) { + auto t = NDArrayFactory::create('c', {1, 1}, {1.0f}); + auto u = NDArrayFactory::create('c', {1, 1}, {2.0f}); + auto v = NDArrayFactory::create('c', {1, 1}, {3.0f}); + auto w = NDArrayFactory::create('c', {1, 1}, {4.0f}); + auto exp = NDArrayFactory::create('c', {1, 4}, {1, 2, 3, 4}); + + sd::ops::concat op; + auto result = op.evaluate({&t, &u, &v, &w}, {}, {1}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ScopeTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ScopeTests.cpp new file mode 100644 index 000000000..64a892387 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ScopeTests.cpp @@ -0,0 +1,167 @@ +/* 
****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 15.10.2017. +// + +#include "testlayers.h" +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class ScopeTests : public testing::Test { +public: + +}; + +TEST_F(ScopeTests, BasicTests_1) { + Graph graph; + + auto x = NDArrayFactory::create_('c', {2, 2}); + x->assign(0.0f); + + auto variableSpace = graph.getVariableSpace(); + variableSpace->putVariable(-1, x); + + sd::ops::Scope opScope; + + auto scopeBody = new Node(OpType_LOGIC, 10, 1); + scopeBody->setName("scopeBody"); + scopeBody->setCustomOp(&opScope); + + + graph.addNode(scopeBody); + + ASSERT_EQ(1, graph.totalNodes()); + + auto scopedB0 = new Node(OpType_SCALAR, 0, 6, {-1}, {}, {}, 1.0f); + scopedB0->markInplace(true); + scopedB0->setScopeInfo(1, "scopeBody"); + + graph.addNode(scopedB0); + + ASSERT_EQ(1, graph.totalNodes()); + +} +/* +TEST_F(ScopeTests, RealTests_1) { + Graph graph; + + auto x = NDArrayFactory::create_('c', {2, 2}); + x->assign(0.0f); + + auto y = NDArrayFactory::create_('c', {2, 2}); + y->assign(0.0); + +// auto scalar = NDArrayFactory::create_('c', {1, 1}); + auto scalar = 
NDArrayFactory::create_(10.f); + //scalar->p(0, 10); + + auto variableSpace = graph.getVariableSpace(); + variableSpace->putVariable(-1, x); + variableSpace->putVariable(-2, y); + variableSpace->putVariable(-3, scalar); + + // just few ops coming before while + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 1, {-1}); + auto nodeB = new Node(OpType_SCALAR, scalar::Add, 2, {1}, {}, {}, 1.0); + + // + auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3); + scopeCondition->setName("scopeCondition"); + sd::ops::Scope opScope; + scopeCondition->setCustomOp(&opScope); + + // this is scope of the body, it'll be executed multiple times + auto scopeBody = new Node(OpType_LOGIC, logic::Scope, 10); + scopeBody->setName("scopeBody"); + scopeBody->setCustomOp(&opScope); + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// filling out condition scope +//////////////////////////////////////////////////////////////////////////////////////////////////// + // this is Sum accumulation, which feed + auto scopedA0 = new Node(OpType_REDUCE_SAME, reduce::Sum, 4, {12}); + scopedA0->setScopeInfo(3, "scopeCondition"); + + // this op compares LT A0 result with variable `scalar` which is 10; + sd::ops::lt_scalar op; + auto scopedA1 = new Node(&op, 5, {4, -3}); + scopedA1->setScopeInfo(3, "scopeCondition"); + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +//// filling out body scope +//////////////////////////////////////////////////////////////////////////////////////////////////// + auto scopedB0 = new Node(OpType_SCALAR, scalar::Add, 6, {12}, {}, {}, 1.0f); + scopedB0->markInplace(false); + scopedB0->setScopeInfo(10, "scopeBody"); + + auto nodeReturn = new Node(OpType_LOGIC, logic::Return, 7, {6}, {12}); + sd::ops::Return opReturn; + nodeReturn->setCustomOp(&opReturn); + nodeReturn->setScopeInfo(10, "scopeBody"); + + // WHILE operations takes 2 scopes 
- :0 is condition scope, and :1 is loop body scope + auto nodeWhile = new Node(OpType_LOGIC, logic::While, 12, {-2, 3, 10}); + sd::ops::While opWhile; + nodeWhile->setCustomOp(&opWhile); + + // adding root nodes first, nothing unusual expected here + graph.addNode(nodeA); + graph.addNode(nodeB); + + // now we're registering our scopes + graph.addNode(scopeCondition); + graph.addNode(scopeBody); + + // at this moment graph should have 4 (four) nodes registered + ASSERT_EQ(4, graph.totalNodes()); + + // adding node that's attached to some scope. so it should be pushed to specific scope + graph.addNode(scopedA0); + + // we should still have 4 ops in graph, because node added above - goes directly into the scope + // thus falls out of the graph direct execution - it can be executed only via Scope + ASSERT_EQ(4, graph.totalNodes()); + + graph.addNode(scopedA1); + graph.addNode(scopedB0); + graph.addNode(nodeReturn); + + // should be still 4. no options here. + ASSERT_EQ(4, graph.totalNodes()); + + // WHILE is valid node, so we expect nodes counter to go up + graph.addNode(nodeWhile); + ASSERT_EQ(5, graph.totalNodes()); + + // now, let's try to execute graph + Nd4jStatus status = GraphExecutioner::execute(&graph); + ASSERT_EQ(ND4J_STATUS_OK, status); + + auto w = variableSpace->getVariable(12, 0)->getNDArray(); + + w->printShapeInfo("w shape"); + ASSERT_NEAR(12.f, w->sumNumber().e(0), 1e-5f); +} +*/ \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ServerRelatedTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ServerRelatedTests.cpp new file mode 100644 index 000000000..7b5f7d3c3 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ServerRelatedTests.cpp @@ -0,0 +1,190 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache 
License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class ServerRelatedTests : public testing::Test { +public: + ServerRelatedTests() { + Environment::getInstance().setDebug(true); + Environment::getInstance().setVerbose(true); + } + + ~ServerRelatedTests() { + Environment::getInstance().setDebug(false); + Environment::getInstance().setVerbose(false); + } +}; +/* +TEST_F(ServerRelatedTests, Basic_Output_Test_1) { + flatbuffers::FlatBufferBuilder builder(4096); + + auto array1 = NDArrayFactory::create_('c', {10, 10}); + auto array2 = NDArrayFactory::create_('c', {10, 10}); + auto array3 = NDArrayFactory::create_('c', {10, 10}); + + array1->assign(1.0f); + array2->assign(2.0f); + array3->assign(3.0f); + + Variable var1(array1, "first", 1); + Variable var2(array2, "second", 2); + Variable var3(array3, "second indexed", 2, 1); + + ExecutionResult result({&var1, &var2, &var3}); + + ASSERT_EQ(*array1, *result.at(0)->getNDArray()); + ASSERT_EQ(*array2, *result.at(1)->getNDArray()); + ASSERT_EQ(*array3, *result.at(2)->getNDArray()); + + ASSERT_EQ(*array1, *result.byId("first")->getNDArray()); + ASSERT_EQ(*array2, *result.byId("second")->getNDArray()); + ASSERT_EQ(*array3, *result.byId("second indexed")->getNDArray()); + + auto 
flatResult = result.asFlatResult(builder); + builder.Finish(flatResult); + auto ptr = builder.GetBufferPointer(); + auto received = GetFlatResult(ptr); + + ExecutionResult restored(received); + ASSERT_EQ(3, restored.size()); + + ASSERT_EQ(*array1, *restored.at(0)->getNDArray()); + ASSERT_EQ(*array2, *restored.at(1)->getNDArray()); + ASSERT_EQ(*array3, *restored.at(2)->getNDArray()); + + ASSERT_EQ(*array1, *restored.byId("first")->getNDArray()); + ASSERT_EQ(*array2, *restored.byId("second")->getNDArray()); + ASSERT_EQ(*array3, *restored.byId("second indexed")->getNDArray()); +} +*/ +#if GRAPH_FILES_OK +TEST_F(ServerRelatedTests, Basic_Execution_Test_1) { + flatbuffers::FlatBufferBuilder builder(4096); + auto oGraph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + oGraph->printOut(); + + auto exp = NDArrayFactory::create('c', {3}, {3.f, 3.f, 3.f}); + + GraphHolder::getInstance().registerGraph(11901L, oGraph); + + auto cGraph = GraphHolder::getInstance().cloneGraph(11901L); + + ASSERT_TRUE(oGraph != cGraph); + + auto flatResult = GraphExecutioner::execute(cGraph, builder, nullptr); + + builder.Finish(flatResult); + auto ptr = builder.GetBufferPointer(); + auto received = GetFlatResult(ptr); + + ExecutionResult restored(received); + ASSERT_EQ(1, restored.size()); + + ASSERT_EQ(exp, *restored.at(0)->getNDArray()); + + delete cGraph; + + GraphHolder::getInstance().dropGraphAny(11901L); +} + +TEST_F(ServerRelatedTests, Basic_Execution_Test_2) { + flatbuffers::FlatBufferBuilder builder(4096); + flatbuffers::FlatBufferBuilder otherBuilder(4096); + auto oGraph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + oGraph->printOut(); + + auto input0 = NDArrayFactory::create('c', {3, 3}, {2.f,2.f,2.f, 2.f,2.f,2.f, 2.f,2.f,2.f}); + auto exp = NDArrayFactory::create('c', {3}, {6.f, 6.f, 6.f}); + + GraphHolder::getInstance().registerGraph(11902L, oGraph); + + auto cGraph = GraphHolder::getInstance().cloneGraph(11902L); + + 
ASSERT_TRUE(oGraph != cGraph); + + // mastering InferenceRequest + InferenceRequest ir(11902L); + ir.appendVariable(1, 0, &input0); + + auto af = ir.asFlatInferenceRequest(otherBuilder); + otherBuilder.Finish(af); + auto fptr = otherBuilder.GetBufferPointer(); + auto fir = GetFlatInferenceRequest(fptr); + + auto flatResult = GraphExecutioner::execute(cGraph, builder, fir); + + builder.Finish(flatResult); + auto ptr = builder.GetBufferPointer(); + auto received = GetFlatResult(ptr); + + ExecutionResult restored(received); + ASSERT_EQ(1, restored.size()); + + ASSERT_EQ(exp, *restored.at(0)->getNDArray()); + + delete cGraph; + + GraphHolder::getInstance().dropGraphAny(11902L); +} + +TEST_F(ServerRelatedTests, BasicExecutionTests_3) { + flatbuffers::FlatBufferBuilder builder(4096); + flatbuffers::FlatBufferBuilder otherBuilder(4096); + auto oGraph = GraphExecutioner::importFromFlatBuffers("./resources/reduce_dim_false.fb"); + oGraph->printOut(); + + auto input0 = NDArrayFactory::create('c', {3, 3}, {2.f,2.f,2.f, 2.f,2.f,2.f, 2.f,2.f,2.f}); + auto exp = NDArrayFactory::create('c', {3}, {6.f, 6.f, 6.f}); + + GraphHolder::getInstance().registerGraph(11903L, oGraph); + + // mastering InferenceRequest + InferenceRequest ir(11903L); + ir.appendVariable(1, 0, &input0); + + auto af = ir.asFlatInferenceRequest(otherBuilder); + otherBuilder.Finish(af); + auto fptr = otherBuilder.GetBufferPointer(); + auto fir = GetFlatInferenceRequest(fptr); + + + auto flatResult = GraphHolder::getInstance().execute(fir->id(), builder, fir); + + builder.Finish(flatResult); + auto ptr = builder.GetBufferPointer(); + auto received = GetFlatResult(ptr); + + ExecutionResult restored(received); + ASSERT_EQ(1, restored.size()); + + ASSERT_EQ(exp, *restored.at(0)->getNDArray()); + + GraphHolder::getInstance().dropGraphAny(11903L); +} +#endif diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeTests.cpp 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeTests.cpp new file mode 100644 index 000000000..d09c99616 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeTests.cpp @@ -0,0 +1,336 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include "testlayers.h" +#include + +using namespace sd; +using namespace sd::graph; + +class ShapeTests : public testing::Test { +public: + +}; + + +TEST_F(ShapeTests, Test_Basics_1) { + Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; + + ASSERT_EQ(2, shape::rank(shape)); + ASSERT_EQ(1, shape::elementWiseStride(shape)); + ASSERT_EQ(5, shape::sizeAt(shape, 0)); + ASSERT_EQ(3, shape::sizeAt(shape, 1)); + ASSERT_EQ('c', shape::order(shape)); +} + + +TEST_F(ShapeTests, Test_Basics_2) { + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + + ASSERT_EQ(4, shape::rank(shape)); + ASSERT_EQ(-1, shape::elementWiseStride(shape)); + ASSERT_EQ(2, shape::sizeAt(shape, 0)); + ASSERT_EQ(3, shape::sizeAt(shape, 1)); + ASSERT_EQ(4, shape::sizeAt(shape, 2)); + ASSERT_EQ(5, shape::sizeAt(shape, 3)); + ASSERT_EQ('f', shape::order(shape)); +} + + 
+TEST_F(ShapeTests, Test_tadLength_1) { + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + int axis[] = {2, 3}; + + ASSERT_EQ(20, shape::tadLength(shape, axis, 2)); +} + + +TEST_F(ShapeTests, Test_ShapeEquality_1) { + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape_GOOD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, 1, 99}; + Nd4jLong shape_BAD[] = {4, 3, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + + + ASSERT_TRUE(shape::equalsSoft(shape, shape_GOOD)); + ASSERT_FALSE(shape::equalsSoft(shape, shape_BAD)); +} + +TEST_F(ShapeTests, Test_ShapeEquality_2) { + Nd4jLong shape[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape_GOOD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 102}; + Nd4jLong shape_BAD[] = {4, 2, 3, 4, 5, 60, 20, 5, 1, 0, -1, 99}; + + + ASSERT_TRUE(shape::equalsStrict(shape, shape_GOOD)); + ASSERT_FALSE(shape::equalsStrict(shape, shape_BAD)); +} + +TEST_F(ShapeTests, Test_Ind2SubC_1) { + Nd4jLong shape[] = {3, 5}; + Nd4jLong c0[2]; + shape::index2coords(0, 2, shape, c0); + + ASSERT_EQ(0, c0[0]); + ASSERT_EQ(0, c0[1]); + + Nd4jLong c1[2]; + shape::index2coords(1, 2, shape, c1); + + ASSERT_EQ(0, c1[0]); + ASSERT_EQ(1, c1[1]); + + Nd4jLong c6[2]; + shape::index2coords(5, 2, shape, c6); + + ASSERT_EQ(1, c6[0]); + ASSERT_EQ(0, c6[1]); +} + + +TEST_F(ShapeTests, Test_ShapeDetector_1) { + Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; + + ASSERT_TRUE(shape::isMatrix(shape)); +} + +TEST_F(ShapeTests, Test_ShapeDetector_2) { + Nd4jLong shape[] = {3, 2, 5, 3, 15, 3, 1, 0, 1, 99}; + + ASSERT_FALSE(shape::isMatrix(shape)); +} + +TEST_F(ShapeTests, Test_ShapeDetector_3) { + Nd4jLong shape[] = {2, 1, 3, 3, 1, 0, 1, 99}; + + ASSERT_FALSE(shape::isColumnVector(shape)); + ASSERT_TRUE(shape::isVector(shape)); + ASSERT_TRUE(shape::isRowVector(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); +} + + +TEST_F(ShapeTests, Test_ShapeDetector_4) { + Nd4jLong shape[] = {2, 3, 1, 1, 1, 0, 1, 99}; + + ASSERT_TRUE(shape::isColumnVector(shape)); + 
ASSERT_TRUE(shape::isVector(shape)); + ASSERT_FALSE(shape::isRowVector(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); +} + +TEST_F(ShapeTests, Test_ShapeDetector_5) { + Nd4jLong shape[] = {2, 1, 1, 1, 1, 0, 1, 99}; + + ASSERT_TRUE(shape::isScalar(shape)); + ASSERT_FALSE(shape::isMatrix(shape)); + + // edge case here. Technicaly it's still a vector with length of 1 + ASSERT_TRUE(shape::isVector(shape)); +} + +TEST_F(ShapeTests, Test_ShapeDetector_6) { + Nd4jLong shape[] = {2, 1, 1, 1, 1, 0, 1, 99}; + + ASSERT_EQ(8, shape::shapeInfoLength(shape)); + ASSERT_EQ(64, shape::shapeInfoByteLength(shape)); +} + +TEST_F(ShapeTests, Test_ShapeDetector_7) { + Nd4jLong shape[] = {3, 1, 1, 1, 1, 1, 1, 0, 1, 99}; + + ASSERT_EQ(10, shape::shapeInfoLength(shape)); + ASSERT_EQ(80, shape::shapeInfoByteLength(shape)); +} + +TEST_F(ShapeTests, Test_Transpose_1) { + Nd4jLong shape[] = {3, 2, 5, 3, 15, 3, 1, 0, 1, 99}; + Nd4jLong exp[] = {3, 3, 5, 2, 1, 3, 15, 0, 1, 102}; + + shape::transposeInplace(shape); + + ASSERT_TRUE(shape::equalsStrict(exp, shape)); +} + +TEST_F(ShapeTests, Test_Transpose_2) { + Nd4jLong shape[] = {2, 5, 3, 3, 1, 0, 1, 99}; + Nd4jLong exp[] = {2, 3, 5, 1, 3, 0, 1, 102}; + + shape::transposeInplace(shape); + + ASSERT_TRUE(shape::equalsStrict(exp, shape)); +} + +TEST_F(ShapeTests, Test_Transpose_3) { + Nd4jLong shape[] = {2, 1, 3, 3, 1, 0, 1, 99}; + Nd4jLong exp[] = {2, 3, 1, 1, 3, 0, 1, 102}; + + shape::transposeInplace(shape); + + ASSERT_TRUE(shape::equalsStrict(exp, shape)); +} + + +TEST_F(ShapeTests, Test_Transpose_4) { + Nd4jLong shape[] = {4, 2, 3, 4, 5, 5, 4, 3, 2, 0, 1, 99}; + Nd4jLong exp[] = {4, 5, 4, 3, 2, 2, 3, 4, 5, 0, 1, 102}; + + shape::transposeInplace(shape); + + ASSERT_TRUE(shape::equalsStrict(exp, shape)); +} + +TEST_F(ShapeTests, Test_Edge_1) { + auto x = NDArrayFactory::create('f', {1, 4, 1, 4}); + x.linspace(1); + + x.reshapei('c', {4, 4}); + + //x.printShapeInfo("reshape0"); + //x.printIndexedBuffer("x i"); + //x.printBuffer("x r"); + + 
x.reshapei({4, 1, 1, 4}); + + //x.printShapeInfo("reshape1"); +} + +TEST_F(ShapeTests, Test_Edge_2) { + auto x = NDArrayFactory::create('c', {1, 4, 1, 3}); + + x.reshapei('c', {3, 4}); + + //x.printShapeInfo("reshape0"); + + x.reshapei({3, 1, 1, 4}); + + //x.printShapeInfo("reshape1"); +} + + +TEST_F(ShapeTests, Test_Remove_Index_1) { + int array[] = {1, 2, 3}; + int idx[] = {0}; + int result[2]; + shape::removeIndex(array, idx, 3, 1, result); + + ASSERT_EQ(2, result[0]); + ASSERT_EQ(3, result[1]); +} + +TEST_F(ShapeTests, Test_Remove_Index_2) { + int array[] = {1, 2, 3}; + int idx[] = {1}; + int result[2]; + shape::removeIndex(array, idx, 3, 1, result); + + ASSERT_EQ(1, result[0]); + ASSERT_EQ(3, result[1]); +} + +TEST_F(ShapeTests, Test_Remove_Index_3) { + int array[] = {1, 2, 3}; + int idx[] = {2}; + int result[2]; + shape::removeIndex(array, idx, 3, 1, result); + + ASSERT_EQ(1, result[0]); + ASSERT_EQ(2, result[1]); +} + +TEST_F(ShapeTests, Test_Remove_Index_4) { + int array[] = {1, 2, 3}; + int idx[] = {0, 2}; + int result[1]; + shape::removeIndex(array, idx, 3, 2, result); + + ASSERT_EQ(2, result[0]); +} + +TEST_F(ShapeTests, Test_Remove_Index_5) { + int array[] = {1, 2, 3}; + int idx[] = {1, 0}; + int result[1]; + shape::removeIndex(array, idx, 3, 2, result); + + ASSERT_EQ(3, result[0]); +} + +TEST_F(ShapeTests, Test_Remove_Index_6) { + int array[] = {1, 2, 3}; + int idx[] = {1, 2}; + int result[1]; + shape::removeIndex(array, idx, 3, 2, result); + + ASSERT_EQ(1, result[0]); +} + +TEST_F(ShapeTests, Tests_Transpose_119_1) { + auto x = NDArrayFactory::create('c', {3, 2}); + auto y = NDArrayFactory::create('c', {2}, {1.0f, 0.0f}); + auto z = NDArrayFactory::create('c', {2, 3}); + + x.linspace(1.f); + + auto e = x.permute({1, 0}); + e.streamline('c'); + + sd::ops::transpose op; + auto result = op.execute({&x, &y}, {&z}, {}, {}, {}); + + ASSERT_EQ(Status::OK(), result); + ASSERT_TRUE(e.isSameShape(z)); + ASSERT_TRUE(e.equalsTo(z)); +} + +TEST_F(ShapeTests, 
Tests_Transpose_119_2) { + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1.f); + + auto exp = x.transpose(); + + sd::ops::transpose op; + auto result = op.evaluate({&x}); + ASSERT_EQ(Status::OK(), result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + +} + +TEST_F(ShapeTests, Tests_Transpose_119_3) { + auto x = NDArrayFactory::create('c', {3, 5}); + x.linspace(1.f); + + auto z = NDArrayFactory::create('c', {5, 3}); + + auto exp = x.transpose(); + + sd::ops::transpose op; + auto result = op.execute({&x}, {&z}, {}, {}, {}); + ASSERT_EQ(Status::OK(), result); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeTests2.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeTests2.cpp new file mode 100644 index 000000000..415c24898 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeTests2.cpp @@ -0,0 +1,820 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by agibsonccc on 1/6/17. 
+// +#include +#include "testinclude.h" +#include +#include + +class OnesTest : public testing::Test { +public: + Nd4jLong shapeBuffer[12] = {4,4,3,1,1,3,1,1,1,0,1,99}; + int dimension[3] = {0,2,3}; + Nd4jLong tadAssertionShape[10] = {3,1,1,4,1,1,3,0,3,99}; + int dimensionLength = 3; +}; + +class LabelTest : public testing::Test { +public: + float labels[450] = {1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0}; + Nd4jLong shapeInfo[8] = {2,150,3,1,150,16384,1,102}; + int dimension[1] = {1}; + int dimensionLength = 1; + Nd4jLong tadShapeInfoAssert[8] = {2,1,3,1,150,16384,150,102}; +}; +class ThreeDTest : public testing::Test { +public: + Nd4jLong shape[3] = {3,4,5}; + Nd4jLong *shapeBuffer; + ThreeDTest() { + shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); + } + ~ThreeDTest() { + delete[] shapeBuffer; + } +}; + +class VectorTest : public testing::Test { + +}; + +class NumTadTests : public testing::Test { +public: + Nd4jLong shape[3] = {3,4,5}; + int dimension = 0; +}; + +class ShapeTest : public testing::Test { +public: + Nd4jLong vectorShape[2] = {1,2}; +}; + +class MatrixTest : public testing::Test { +public: + int rows = 3; + int cols = 4; + int rank = 2; + int dims[2] = {0,1}; + Nd4jLong expectedShapes[2][2] = { + {1,3}, + {1,4} + }; + Nd4jLong expectedStrides[2][2] = { + {1,4}, + {1,1} + }; +}; + +class TADStall : public testing::Test { +public: + Nd4jLong shape[4] = {3,3,4,5}; + int dimensions[3] = {1,2,3}; +}; + +class TensorOneDimTest : public testing::Test { +public: + int rows = 3; + int cols = 4; + int dim2 = 5; + int rank = 3; + int dims[3] = {0,1,2}; + Nd4jLong expectedShapes[3][2] = { + {1,3}, + {1,4}, + {1,5} + }; + Nd4jLong expectedStrides[3][2] = { + {1,20}, + {1,5}, + {1,1} + }; +}; + +class TensorTwoDimTest : public testing::Test { +public: + //From a 3d array: + int rows = 3; + int cols = 4; + int dim2 = 5; + int dimensionLength = 2; + int dims[3][2] = { + {0,1},{0,2},{1,2} + }; + + Nd4jLong shape[3] {rows,cols,dim2}; + + //Along dimension 0,1: expect matrix with shape [rows,cols] + //Along dimension 0,2: expect matrix with shape [rows,dim2] + //Along dimension 1,2: expect matrix with shape [cols,dim2] + Nd4jLong expectedShapes[3][2] = { + 
{rows,cols}, + {rows,dim2}, + {cols,dim2} + }; + + Nd4jLong expectedStrides[3][2] = { + {20,5}, + {20,1}, + {5,1} + }; + +}; + +class TensorTwoFromFourDDimTest : public testing::Test { +public: + //From a 3d array: + int rows = 3; + int cols = 4; + int dim2 = 5; + int dim3 = 6; + Nd4jLong shape[4] = {rows,cols,dim2,dim3}; + int dimensionLength = 2; + //Along dimension 0,1: expect matrix with shape [rows,cols] + //Along dimension 0,2: expect matrix with shape [rows,dim2] + //Along dimension 0,3: expect matrix with shape [rows,dim3] + //Along dimension 1,2: expect matrix with shape [cols,dim2] + //Along dimension 1,3: expect matrix with shape [cols,dim3] + //Along dimension 2,3: expect matrix with shape [dim2,dim3] + + int dims[6][2] = { + {0,1}, + {0,2}, + {0,3}, + {1,2}, + {1,3}, + {2,3} + }; + + Nd4jLong expectedShapes[6][2] = { + {rows,cols}, + {rows,dim2}, + {rows,dim3}, + {cols,dim2}, + {cols,dim3} + ,{dim2,dim3} + }; + + Nd4jLong expectedStrides[6][2] = { + {120,30}, + {120,6}, + {120,1}, + {30,6}, + {30,1}, + {6,1} + }; +}; + + +class OrderTest : public testing::Test { +public: + Nd4jLong expected[8] = {2,3,4,1,3,0,0,102}; + Nd4jLong test[8] = {2,3,4,1,3,0,0,102}; + +}; + + +class LeadingOnes : public testing::Test { +public: + Nd4jLong shapeBufferF[16] = {4,1,1,4,4,1,1,1,4,16384,1,102}; // shapes with data type DOUBLE + Nd4jLong shapeBufferC[16] = {4,1,1,4,4,16,16,4,1,16384,1,99}; + int dimensionLength = 2; + int dimension[2] = {2,3}; + Nd4jLong tadAssertionC[10] = {3,4,4,1,4,1,16,16384,1,99}; + Nd4jLong tadCAssertionF[10] = {3,4,4,1,1,4,1,16384,1,102}; +}; + + +TEST_F(LeadingOnes,OnesTest) { + + shape::TAD *cTad = new shape::TAD; + cTad->init(shapeBufferC,dimension,dimensionLength); + cTad->createTadOnlyShapeInfo(); + cTad->createOffsets(); + shape::TAD *fTad = new shape::TAD; + fTad->init(shapeBufferF,dimension,dimensionLength); + fTad->createTadOnlyShapeInfo(); + fTad->createOffsets(); + // shape::printShapeInfoLinear(cTad->tadOnlyShapeInfo); + // 
shape::printShapeInfoLinear(fTad->tadOnlyShapeInfo); + ASSERT_TRUE(arrsEquals(10, tadCAssertionF, fTad->tadOnlyShapeInfo)); + ASSERT_TRUE(arrsEquals(10, tadAssertionC, cTad->tadOnlyShapeInfo)); + + delete cTad; + delete fTad; +} + + +class NormalThreeFourFive : public testing::Test { +public: + Nd4jLong assertionBuffer[8] = {2, 3, 4, 20, 5, 16384, 5, 99}; + Nd4jLong inputShapeBuffer[10] = {3,3,4,5,20,5,1,16384,1,99}; + int dimensionLength = 2; + int dimension[2] = {0,1}; +}; + + +TEST_F(NormalThreeFourFive,DimensionTest) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer,dimension,dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + ASSERT_TRUE(arrsEquals(8,assertionBuffer,tad->tadOnlyShapeInfo)); + + delete tad; +} + +class DimensionWarning : public testing::Test { +public: + int dimensionLength = 2; + int dimensions[2] = {0,1}; + Nd4jLong shape[3] = {1,5,1}; + Nd4jLong *shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); + + ~DimensionWarning() { + delete[] shapeBuffer; + } +}; + + +TEST_F(DimensionWarning,ShapeWarning) { + shape::TAD *tad = new shape::TAD; + tad->init(shapeBuffer,dimensions,dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + delete tad; +} + + +class TadRank : public testing::Test { + Nd4jLong shapeBuffer[12] = {4,2,1,3,3,9,9,3,1,0,1,99}; + int dimensionLength = 2; + int dimension[2] = {2,3}; + +}; + +class TestRemoveIndex : public testing::Test {}; + +class TestReverseCopy : public testing::Test {}; + +class TestConcat : public testing::Test {}; + +class SliceVectorTest : public testing::Test {}; + +class SliceMatrixTest : public testing::Test {}; + +class SliceTensorTest : public testing::Test {}; + +class ElementWiseStrideTest : public testing::Test { +public: + Nd4jLong shape[3] = {3,4,5}; + Nd4jLong stride[2] = {20,5}; + int elementWiseStrideAssertion = -1; +}; + +class PermuteTest : public testing::Test{}; + +class LengthPerSliceTest : 
public testing::Test{}; + +class ExpectedValuesTest : public testing::Test { +public: + Nd4jLong mainShape[4] = {9,7,5,3}; + int testDimensions[3] = {0,2,3}; + +}; + +class BeginOneTadTest : public testing::Test { +public: + Nd4jLong assertionShapeBuffer[8] = {2,3,5,1,3,16384,1,102}; + Nd4jLong inputShapeBuffer[10] = {3,1,3,5,1,1,3,16384,0,102}; + int dimensionLength = 2; + int dimension[2] = {1,2}; + //error: [2,1,1,1,1,0,1,97] +}; + +class FourDTest : public testing::Test { + /** + * INDArray array3d = Nd4j.ones(1, 10, 10); +array3d.sum(1); + +INDArray array4d = Nd4j.ones(1, 10, 10, 10); +INDArray sum40 = array4d.sum(0); + */ +public: + Nd4jLong threeDShape[3] = {1,10,10}; + Nd4jLong fourDShape[4] = {1,10,10,10}; + Nd4jLong *threeDShapeBuffer = nullptr,*fourDShapeBuffer = nullptr; + int dimensionThree = 1; + int dimensionThreeTwo = 0; + int dimensionFour = 0; + int dimensionLength = 1; + FourDTest() { + threeDShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'f', 3, threeDShape); + fourDShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'f', 4, fourDShape); + } + ~FourDTest() { + if(threeDShapeBuffer != nullptr) + delete[] threeDShapeBuffer; + if(fourDShapeBuffer != nullptr) + delete[] fourDShapeBuffer; + } + + + +}; + + + +TEST_F(FourDTest,ThreeDFourDTest) { + shape::TAD *threeTadTwo = new shape::TAD; + threeTadTwo->init(threeDShapeBuffer,&dimensionThreeTwo,dimensionLength); + threeTadTwo->createTadOnlyShapeInfo(); + threeTadTwo->createOffsets(); + + shape::TAD *threeTad = new shape::TAD; + threeTad->init(threeDShapeBuffer,&dimensionThree,dimensionLength); + threeTad->createTadOnlyShapeInfo(); + threeTad->createOffsets(); + + shape::TAD *fourTad = new shape::TAD; + fourTad->init(fourDShapeBuffer,&dimensionFour,dimensionLength); + fourTad->createTadOnlyShapeInfo(); + fourTad->createOffsets(); + + delete threeTadTwo; + delete threeTad; + delete fourTad; +} + + + +class RowVectorOnesTest : public testing::Test { +public: + 
Nd4jLong shapeBuffer[12] = {4,4,3,1,1,3,1,1,1,8192,1,99}; // float32 type of shape + float data[12] = {1,2,3,4,5,6,7,8,9,10,11,12}; + Nd4jLong assertionBuffer[10] = {3,4,1,1,3,1,1,8192,0,99}; + int dimensionLength = 3; + int dimension[3] = {0,2,3}; +}; + +// TEST_F(RowVectorOnesTest,TadShape) { +// shape::TAD *tad = new shape::TAD(shapeBuffer,dimension,dimensionLength); +// tad->createTadOnlyShapeInfo(); +// tad ->createOffsets(); +// ASSERT_TRUE(arrsEquals(10,assertionBuffer,tad->tadOnlyShapeInfo)); +// delete tad; +// } + + + +class SixDTest : public testing::Test { +public: + Nd4jLong inputShapeBuffer[16] = {6,1,1,4,4,4,4,1,1,1,4,16,64,16384,1,102}; // shape with double data type + int dimensionLength = 2; + int dimension[2] = {2,3}; + Nd4jLong assertionShapeBuffer[8] = {2,4,4,1,4,16384,1,102}; // also double typed shape +}; + +TEST_F(SixDTest, SixDWithOnes) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer,dimension,dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + // shape::printShapeInfoLinear(inputShapeBuffer); + // shape::printShapeInfoLinear(tad->tadOnlyShapeInfo); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tad->tadOnlyShapeInfo)); + delete tad; +} + +class TrailingTest : public testing::Test { +public: + Nd4jLong inputShapeBuffer[12] = {4,5,5,5,1,1,5,25,125,16384,1,102}; + int dimensionLength = 1; + int dimension[1] = {0}; + Nd4jLong assertionShapeBuffer[8] = {2,1,5,125,1,16384,1,102}; +}; + +TEST_F(TrailingTest,TrailingTest2) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer,dimension,dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tad->tadOnlyShapeInfo)); + delete tad; +} + + +class ScalarTest : public testing::Test { +public: + Nd4jLong inputShapeBuffer[12] = {3,2,3,4,12,4,1,16384,1,99}; + int dimensionLength = 1; + int dimension[1] = {1}; + Nd4jLong 
assertionShapeBuffer[8] = {2,1,1,1,1,16384,1,99}; +}; +/* +TEST_F(ScalarTest,ScalarTest2) { + shape::TAD *tad = new shape::TAD(inputShapeBuffer,dimension,dimensionLength); + tad->createTadOnlyShapeInfo(); + tad ->createOffsets(); + //[2,1,1,1,1,0,1,97] + shape::printShapeInfoLinear(tad->tadOnlyShapeInfo); + ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tad->tadOnlyShapeInfo)); +} +*/ + + + +class ThreeTest : public testing::Test { +public: + Nd4jLong inputShapeBuffer[10] = {3,4,3,2,6,2,1,16384,1,99}; + int dimensionLength = 1; + int dimension[1] = {0}; + Nd4jLong assertionShapeBuffer[8] = {2,1,4,1,6,16384,6,99}; +}; + +TEST_F(ThreeTest,ThreeTest ) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer,dimension,dimensionLength); + tad->createTadOnlyShapeInfo(); + tad->createOffsets(); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tad->tadOnlyShapeInfo)); + delete tad; +} + + +TEST_F(BeginOneTadTest, TadTest) { + shape::TAD *tad = new shape::TAD; + tad->init(inputShapeBuffer,dimension,dimensionLength); + tad->createTadOnlyShapeInfo(); + auto tadShapeBuffer = tad->tadOnlyShapeInfo; + // shape::printShapeInfoLinear(tadShapeBuffer); + //[2,1,1,1,1,0,1,97] + ASSERT_TRUE(arrsEquals(8,assertionShapeBuffer,tadShapeBuffer)); + + delete tad; +} + +/* +TEST_F(OnesTest,OnesTadTest) { + shape::TAD *tad = new shape::TAD(shapeBuffer,dimension,dimensionLength); + int *tadShapeBuffer = tad->shapeInfoOnlyShapeAndStride(); + ASSERT_TRUE(arrsEquals(10,tadAssertionShape,tadShapeBuffer)); + delete[] tadShapeBuffer; +} +*/ + +TEST_F(LabelTest,LabelTad) { + shape::TAD *tad = new shape::TAD; + tad->init(shapeInfo,dimension,dimensionLength); + tad->createTadOnlyShapeInfo(); + auto tadShapeInfo = tad->tadOnlyShapeInfo; + ASSERT_TRUE(arrsEquals(8,tadShapeInfoAssert,tadShapeInfo)); + + delete tad; +} + +TEST_F(ExpectedValuesTest,TadTest) { + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, mainShape); + shape::TAD *tad = 
new shape::TAD; + tad->init(shapeBuffer,testDimensions,3); + tad->createTadOnlyShapeInfo(); + auto shapeInfo = tad->tadOnlyShapeInfo; + + delete tad; + delete[] shapeBuffer; +} + +TEST_F(OrderTest,testOrder) { + int rank = shape::rank(expected); + auto expectedShape = shape::shapeOf(expected); + auto expectedStride = shape::stride(expected); + int realOrder = shape::getOrder(rank,expectedShape,expectedStride,1); + int expectedOrder = 102; + ASSERT_EQ(expectedOrder,realOrder); +} + + +TEST_F(ThreeDTest,TensorAlongDimensionTest) { + int dimension[2] = {0,2}; + Nd4jLong tadShapeAssertion[2] = {3,5}; + Nd4jLong strideAssertion[2] = {20,1}; + shape::TAD *tad = new shape::TAD; + tad->init(0,this->shapeBuffer,dimension,2); + tad->createTadOnlyShapeInfo(); + auto shapeBufferTest = tad->tadOnlyShapeInfo; + auto shapeTest = shape::shapeOf(shapeBufferTest); + auto strideTest = shape::stride(shapeBufferTest); + ASSERT_TRUE(arrsEquals(2,tadShapeAssertion,shapeTest)); + ASSERT_TRUE(arrsEquals(2,strideAssertion,strideTest)); + delete tad; +} + + +TEST_F(NumTadTests,TadTest) { + auto shape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, this->shape); + shape::TAD *tad = new shape::TAD; + tad->init(shape,&dimension,1); + int numTads = shape::tensorsAlongDimension(shape,&dimension,1); + ASSERT_EQ(20,numTads); + delete[] shape; + delete tad; +} + +TEST_F(TADStall,TestStall) { + auto shapeInfo = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shape); + shape::TAD *tad = new shape::TAD; + tad->init(0,shapeInfo,this->dimensions,3); + tad->createTadOnlyShapeInfo(); + Nd4jLong *test = tad->tadOnlyShapeInfo; + + delete[] shapeInfo; + delete tad; +} + + +TEST_F(LengthPerSliceTest,TestLengthPerSlice) { + Nd4jLong firstShape[2] = {5,3}; + int lengthPerSliceAssertionFirst = 3; + int firstDimension = 0; + int lengthPerSliceTest = shape::lengthPerSlice(2,firstShape,&firstDimension,1); + ASSERT_EQ(lengthPerSliceAssertionFirst,lengthPerSliceTest); +} + 
+TEST_F(PermuteTest,PermuteShapeBufferTest) { + int permuteOrder[4] = {3,2,1,0}; + int normalOrder[4] = {0,1,2,3}; + Nd4jLong shapeToPermute[4] = {5,3,2,6}; + Nd4jLong permutedOrder[4] = {6,2,3,5}; + auto shapeBufferOriginal = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shapeToPermute); + auto assertionShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shapeToPermute); + shape::permuteShapeBufferInPlace(shapeBufferOriginal,normalOrder,shapeBufferOriginal); + EXPECT_TRUE(arrsEquals(4,assertionShapeBuffer,shapeBufferOriginal)); + + auto backwardsAssertion = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, permutedOrder); + auto permuted = shape::permuteShapeBuffer(assertionShapeBuffer, permuteOrder); + EXPECT_TRUE(arrsEquals(4, backwardsAssertion, permuted)); + + + delete[] permuted; + delete[] backwardsAssertion; + delete[] shapeBufferOriginal; + delete[] assertionShapeBuffer; +} + +TEST_F(ElementWiseStrideTest,ElementWiseStrideTest) { + +} + +TEST_F(SliceVectorTest,RowColumnVectorTest) { + Nd4jLong rowVectorShape[2] = {1,5}; + auto rowVectorShapeInfo = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVectorShape); + Nd4jLong colVectorShape[2] = {5,1}; + auto colVectorShapeInfo = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, colVectorShape); + Nd4jLong *sliceRow = shape::sliceOfShapeBuffer(0,rowVectorShapeInfo); + EXPECT_TRUE(arrsEquals(2,rowVectorShapeInfo,sliceRow)); + Nd4jLong *scalarSliceInfo = shape::createScalarShapeInfo(); + Nd4jLong *scalarColumnAssertion = shape::createScalarShapeInfo(); + scalarColumnAssertion[shape::shapeInfoLength(2) - 3] = 1; + Nd4jLong *scalarColumnTest = shape::sliceOfShapeBuffer(1L,colVectorShapeInfo); + EXPECT_TRUE(arrsEquals(2,scalarColumnAssertion,scalarColumnTest)); + + delete[] scalarColumnTest; + delete[] scalarColumnAssertion; + delete[] scalarSliceInfo; + delete[] sliceRow; + delete[] rowVectorShapeInfo; + delete[] 
colVectorShapeInfo; +} + +TEST_F(SliceTensorTest,TestSlice) { + Nd4jLong shape[3] = {3,3,2}; + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); + Nd4jLong sliceShape[2] = {3,2}; + auto sliceShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, sliceShape); + Nd4jLong *testSlice = shape::sliceOfShapeBuffer(0,shapeBuffer); + EXPECT_TRUE(arrsEquals(2,sliceShapeBuffer,testSlice)); + delete[] testSlice; + delete[] shapeBuffer; + delete[] sliceShapeBuffer; + +} + +TEST_F(SliceMatrixTest,TestSlice) { + Nd4jLong shape[2] = {3,2}; + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, shape); + Nd4jLong sliceShape[2] = {1,2}; + auto sliceShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, sliceShape); + Nd4jLong *testSlice = shape::sliceOfShapeBuffer(0,shapeBuffer); + EXPECT_TRUE(arrsEquals(2,sliceShapeBuffer,testSlice)); + delete[] testSlice; + delete[] shapeBuffer; + delete[] sliceShapeBuffer; + +} + + +TEST_F(TestConcat,ConcatTest) { + Nd4jLong firstArr[2] = {1,2}; + Nd4jLong secondConcat[2] = {3,4}; + Nd4jLong concatAssertion[4] = {1,2,3,4}; + Nd4jLong *concatTest = shape::concat(firstArr,2,secondConcat,2); + EXPECT_TRUE(arrsEquals(4,concatAssertion,concatTest)); + delete[] concatTest; +} + +TEST_F(TestReverseCopy,ReverseCopyTest) { + Nd4jLong toCopy[5] = {0,1,2,3,4}; + Nd4jLong reverseAssertion[5] = {4,3,2,1,0}; + Nd4jLong *reverseCopyTest = shape::reverseCopy(toCopy,5); + EXPECT_TRUE(arrsEquals(5,reverseAssertion,reverseCopyTest)); + delete[] reverseCopyTest; +} + +TEST_F(TestRemoveIndex,Remove) { + Nd4jLong input[5] = {0,1,2,3,4}; + Nd4jLong indexesToRemove[3] = {0,1,2}; + Nd4jLong indexesToRemoveAssertion[2] = {3,4}; + Nd4jLong *indexesToRemoveTest = shape::removeIndex(input,indexesToRemove, (Nd4jLong) 5, (Nd4jLong) 3); + EXPECT_TRUE(arrsEquals(2,indexesToRemoveAssertion,indexesToRemoveTest)); + delete[] indexesToRemoveTest; +} + 
+TEST_F(TensorTwoFromFourDDimTest,TadTwoFromFourDimTest) { + //Along dimension 0,1: expect matrix with shape [rows,cols] + //Along dimension 0,2: expect matrix with shape [rows,dim2] + //Along dimension 0,3: expect matrix with shape [rows,dim3] + //Along dimension 1,2: expect matrix with shape [cols,dim2] + //Along dimension 1,3: expect matrix with shape [cols,dim3] + //Along dimension 2,3: expect matrix with shape [dim2,dim3] + auto baseShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 4, shape); + for(int i = 0; i < 3; i++) { + int *dimArr = dims[i]; + Nd4jLong *expectedShape = expectedShapes[i]; + shape::TAD *tad = new shape::TAD; + tad->init(baseShapeBuffer,dimArr,dimensionLength); + auto expectedShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', dimensionLength, expectedShape); + tad->createTadOnlyShapeInfo(); + Nd4jLong *testShapeBuffer = tad->tadOnlyShapeInfo; + EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer),expectedShape,shape::shapeOf(testShapeBuffer))); + EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer),expectedStrides[i],shape::stride(testShapeBuffer))); + + delete[] expectedShapeBuffer; + delete tad; + } + + delete[] baseShapeBuffer; +} + +TEST_F(TensorTwoDimTest,TadTwoDimTest) { + //Along dimension 0,1: expect matrix with shape [rows,cols] + //Along dimension 0,2: expect matrix with shape [rows,dim2] + //Along dimension 1,2: expect matrix with shape [cols,dim2] + auto baseShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 3, shape); + + for(int i = 0; i < 3; i++) { + int *dimArr = dims[i]; + Nd4jLong *expectedShape = expectedShapes[i]; + shape::TAD *tad = new shape::TAD; + tad->init(baseShapeBuffer,dimArr,dimensionLength); + auto expectedShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', dimensionLength, expectedShape); + tad->createTadOnlyShapeInfo(); + Nd4jLong *testShapeBuffer = tad->tadOnlyShapeInfo; + Nd4jLong *expectedStride = 
expectedStrides[i]; + Nd4jLong *testShape = shape::shapeOf(testShapeBuffer); + Nd4jLong *testStride = shape::stride(testShapeBuffer); + EXPECT_TRUE(arrsEquals(shape::rank(expectedShapeBuffer),expectedShape,testShape)); + EXPECT_TRUE(arrsEquals(shape::rank(testShapeBuffer),expectedStride,testStride)); + + delete[] expectedShapeBuffer; + delete tad; + + } + + delete[] baseShapeBuffer; + + +} + +TEST_F(TensorOneDimTest,TadDimensionsForTensor) { + Nd4jLong shape[3] = {rows,cols,dim2}; + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', rank, shape); + + for(int i = 0; i < rank; i++) { + //Along dimension 0: expect row vector with length 'dims[i]' + shape::TAD *zero = new shape::TAD; + zero->init(shapeBuffer,&dims[i],1); + zero->createTadOnlyShapeInfo(); + Nd4jLong *testDimZeroShapeBuffer = zero->tadOnlyShapeInfo; + Nd4jLong *testShape = shape::shapeOf(testDimZeroShapeBuffer); + Nd4jLong *testStride = shape::stride(testDimZeroShapeBuffer); + EXPECT_TRUE(arrsEquals(2,expectedShapes[i],testShape)); + EXPECT_TRUE(arrsEquals(2,expectedStrides[i],testStride)); + + delete zero; + } + + delete[] shapeBuffer; +} + + +TEST_F(MatrixTest,TadDimensionsForMatrix) { + Nd4jLong shape[2] = {rows,cols}; + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', rank, shape); + + shape::TAD *dimZero = new shape::TAD; + dimZero->init(shapeBuffer,&dims[0],1); + shape::TAD *dimOne = new shape::TAD; + dimOne->init(shapeBuffer,&dims[1],1); + //Along dimension 0: expect row vector with length 'rows' + Nd4jLong rowVectorShape[2] = {1,rows}; + auto expectedDimZeroShape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVectorShape); + dimZero->createTadOnlyShapeInfo(); + Nd4jLong *testDimZero = dimZero->tadOnlyShapeInfo; + EXPECT_TRUE(arrsEquals(2,expectedShapes[0],shape::shapeOf(testDimZero))); + EXPECT_TRUE(arrsEquals(2,expectedStrides[0],shape::stride(testDimZero))); + + delete[] expectedDimZeroShape; + //Along 
dimension 1: expect row vector with length 'cols' + Nd4jLong rowVectorColShape[2] {1,cols}; + auto expectedDimOneShape = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVectorColShape); + dimOne->createTadOnlyShapeInfo(); + Nd4jLong *testDimOneShape = dimOne->tadOnlyShapeInfo; + EXPECT_TRUE(arrsEquals(2,expectedShapes[1],shape::shapeOf(testDimOneShape))); + EXPECT_TRUE(arrsEquals(2,expectedStrides[1],shape::stride(testDimOneShape))); + + delete[] expectedDimOneShape; + delete dimOne; + delete dimZero; + delete[] shapeBuffer; +} + +TEST_F(VectorTest,VectorTadShape) { + Nd4jLong rowVector[2] = {2,2}; + auto rowBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, rowVector); + int rowDimension = 1; + + Nd4jLong columnVector[2] = {2,2}; + auto colShapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, columnVector); + int colDimension = 0; + + + shape::TAD *rowTad = new shape::TAD; + rowTad->init(rowBuffer,&rowDimension,1); + rowTad->createTadOnlyShapeInfo(); + Nd4jLong *rowTadShapeBuffer = rowTad->tadOnlyShapeInfo; + Nd4jLong *rowTadShape = shape::shapeOf(rowTadShapeBuffer); + shape::TAD *colTad = new shape::TAD; + colTad->init(colShapeBuffer,&colDimension,1); + colTad->createTadOnlyShapeInfo(); + Nd4jLong *colTadShapeBuffer = colTad->tadOnlyShapeInfo; + Nd4jLong *colTadShape = shape::shapeOf(colTadShapeBuffer); + Nd4jLong assertionShape[2] = {1,2}; + Nd4jLong assertionStride[2] = {1,1}; + EXPECT_TRUE(arrsEquals(2,assertionShape,rowTadShape)); + EXPECT_TRUE(arrsEquals(2,assertionStride,shape::stride(rowTadShapeBuffer))); + EXPECT_TRUE(arrsEquals(2,assertionShape,colTadShape)); + + delete[] rowBuffer; + delete[] colShapeBuffer; + delete rowTad; + delete colTad; +} + + + + +TEST_F(ShapeTest,IsVector) { + ASSERT_TRUE(shape::isVector(vectorShape,2)); +} + +TEST_F(VectorTest,LinspaceCombinationTest) { + int rows = 3; + int cols = 4; + int len = rows * cols; + double *linspaced = linspace(1,rows * cols,len); 
+ Nd4jLong shape[2] = {rows,cols}; + auto shapeBuffer = sd::ShapeBuilders::createShapeInfo(sd::DataType::FLOAT32, 'c', 2, shape); + + delete[] shapeBuffer; + delete[] linspaced; +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeUtilsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeUtilsTests.cpp new file mode 100644 index 000000000..58d4dd4cb --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ShapeUtilsTests.cpp @@ -0,0 +1,295 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 01.11.2017. 
+// + +#include "testlayers.h" +#include +#include + + +using namespace sd; +using namespace sd::graph; + +class ShapeUtilsTests : public testing::Test { +public: + +}; + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalDimsToExclude_1) { + std::vector res = ShapeUtils::evalDimsToExclude(3, {0}); + + ASSERT_EQ(2, res.size()); + ASSERT_EQ(1, res.at(0)); + ASSERT_EQ(2, res.at(1)); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalDimsToExclude_2) { + std::vector res = ShapeUtils::evalDimsToExclude(4, {2, 3}); + + ASSERT_EQ(2, res.size()); + ASSERT_EQ(0, res.at(0)); + ASSERT_EQ(1, res.at(1)); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_1) +{ + + Nd4jLong xShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 1, 2, 2, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {3, 3, 2, 2, 4, 2, 1, 8192, 1, 99}; + + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); + + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); + + ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_2) +{ + + Nd4jLong xShapeInfo[] = {4, 8, 1, 6, 1, 6, 6, 1, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {3, 7, 1, 5, 5, 5, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {4, 8, 7, 6, 5, 210, 30, 5, 1, 8192, 1, 99}; + + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); + + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); + + ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_3) +{ + + Nd4jLong xShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 
99}; + Nd4jLong yShapeInfo[] = {3, 15, 1, 5, 5, 5, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {3, 15, 3, 5, 15, 5, 1, 8192, 1, 99}; + + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); + + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); + + ASSERT_TRUE(shape::equalsStrict(expShapeInfo, newShapeInfo)); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, EvalBroadcastShapeInfo_4) +{ + + Nd4jLong xShapeInfo[] = {3, 8, 1, 3, 3, 3, 1, 8192, 1, 99}; + Nd4jLong yShapeInfo[] = {2, 4, 3, 3, 1, 8192, 1, 99}; + Nd4jLong expShapeInfo[] = {3, 8, 4, 3, 12, 3, 1, 8192, 1, 99}; + + NDArray x(xShapeInfo); + NDArray y(yShapeInfo); + + const Nd4jLong *newShapeInfo = nullptr; + ShapeUtils::evalBroadcastShapeInfo(x, y, false, newShapeInfo, nullptr); + //for(int i=0; i<2*newShapeInfo[0]+4; ++i) + // std::cout<('c',{2,3,4,5}); + auto expected = NDArrayFactory::create('c', {2,4,5}); + std::vector dimensions = {1}; + + auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo()); + + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test2) +{ + + auto x = NDArrayFactory::create('c',{2,3,4,5}); + auto expected = NDArrayFactory::create('c', {2,1,4,5}); + std::vector dimensions = {1}; + + auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); + + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test3) +{ + + auto x = NDArrayFactory::create('c',{2,3,4,5}); + auto expected = NDArrayFactory::create('c', {1,1,1,5}); + std::vector dimensions = {0,1,2}; + + auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); + + 
ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); + +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalReduceShapeInfo_test4) +{ + + auto x = NDArrayFactory::create('c',{2,3,4,5}); + auto expected = NDArrayFactory::create('c', {1,1,1,1}); + std::vector dimensions = {0,1,2,3}; + + auto newShapeInfo = ShapeUtils::evalReduceShapeInfo('c', dimensions, x.shapeInfo(), true); + + ASSERT_TRUE(shape::shapeEquals(expected.shapeInfo(), newShapeInfo)); +} + +TEST_F(ShapeUtilsTests, Test_Strings_1) { + auto x = NDArrayFactory::create('c', {2, 3, 4, 5}); + std::string exp("[2, 3, 4, 5]"); + + auto s = ShapeUtils::shapeAsString(&x); + + ASSERT_EQ(exp, s); +} + +TEST_F(ShapeUtilsTests, Test_Backward_Axis_1) { + auto x = NDArrayFactory::create('c', {2, 4, 3}); + auto y = NDArrayFactory::create('c', {4, 3}); + std::vector exp({0}); + + auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); + + ASSERT_EQ(exp, z); +} + +TEST_F(ShapeUtilsTests, Test_Backward_Axis_2) { + auto x = NDArrayFactory::create('c', {2, 4, 4, 3}); + auto y = NDArrayFactory::create('c', {4, 1, 3}); + std::vector exp({0, 2}); + + auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); + + ASSERT_EQ(exp, z); +} + + +TEST_F(ShapeUtilsTests, Test_Backward_Axis_3) { + auto x = NDArrayFactory::create('c', {2, 4, 4, 3}); + auto y = NDArrayFactory::create('c', {2, 1, 1, 3}); + std::vector exp({1, 2}); + + auto z = ShapeUtils::evalBroadcastBackwardAxis(y.shapeInfo(), x.shapeInfo()); + + ASSERT_EQ(exp, z); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalPermutFromTo_test1) { + + int a=1, b=2, c=3, d=4; + std::vector expected = {2, 3, 0, 1}; + + std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {c,d,a,b}); + + ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); + +} + 
+////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalPermutFromTo_test2) { + + int a=1, b=2, c=3, d=4; + std::vector expected = {0, 1, 3, 2}; + + std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,d,c}); + + ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); + +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalPermutFromTo_test3) { + + int a=2, b=2, c=3, d=2; + std::vector expected = {0, 1, 3, 2}; + + std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,d,c}); + + ASSERT_TRUE(std::equal(begin(expected), end(expected), begin(result))); + +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalPermutFromTo_test4) { + + int a=2, b=3, c=4, d=5; + + std::vector result = ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,c,d}); + + ASSERT_TRUE(result.empty()); + +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalPermutFromTo_test5) { + + int a=1, b=2, c=3, d=4; + + // EXPECT_THROW(ShapeUtils::evalPermutFromTo({a,b,c,d}, {c,d,a,8}), const char*); + ASSERT_TRUE(1); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, evalPermutFromTo_test6) { + + int a=1, b=2, c=3, d=4; + + // EXPECT_THROW(ShapeUtils::evalPermutFromTo({a,b,c,d}, {a,b,c,d,d}), const char*); + ASSERT_TRUE(1); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, isPermutNecessary_test1) { + + ASSERT_TRUE(ShapeUtils::isPermutNecessary({1,0,2,3})); +} + +////////////////////////////////////////////////////////////////// +TEST_F(ShapeUtilsTests, isPermutNecessary_test2) { + + ASSERT_TRUE(!ShapeUtils::isPermutNecessary({0,1,2,3})); +} + + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SingleDimTests.cpp 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SingleDimTests.cpp new file mode 100644 index 000000000..23931dee4 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SingleDimTests.cpp @@ -0,0 +1,187 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class SingleDimTests : public testing::Test { +public: + +}; + +TEST_F(SingleDimTests, Test_Create_1) { + auto x = NDArrayFactory::create('c', {5}, {1, 2, 3, 4, 5}); + ASSERT_EQ(5, x.lengthOf()); + ASSERT_EQ(1, x.rankOf()); + ASSERT_TRUE(x.isVector()); + ASSERT_TRUE(x.isRowVector()); + ASSERT_FALSE(x.isMatrix()); +} + +TEST_F(SingleDimTests, Test_Add_1) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {2, 3, 4}); + + x += 1.0f; + + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + + +TEST_F(SingleDimTests, Test_Pairwise_1) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {2, 4, 6}); + + x += x; 
+ + ASSERT_TRUE(exp.isSameShape(&x)); + ASSERT_TRUE(exp.equalsTo(&x)); +} + +TEST_F(SingleDimTests, Test_Concat_1) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto y = NDArrayFactory::create('c', {3}, {4, 5, 6}); + auto exp = NDArrayFactory::create('c', {6}, {1, 2, 3, 4, 5, 6}); + + sd::ops::concat op; + auto result = op.evaluate({&x, &y}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(SingleDimTests, Test_Reduce_1) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + + float r = x.reduceNumber(reduce::Sum).e(0); + + ASSERT_NEAR(6.0f, r, 1e-5f); +} + +TEST_F(SingleDimTests, Test_IndexReduce_1) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + + auto r = x.indexReduceNumber(indexreduce::IndexMax).e(0); + + ASSERT_NEAR(2, r, 1e-5f); +} + + +TEST_F(SingleDimTests, Test_ExpandDims_1) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {1, 3}, {1, 2, 3}); + + sd::ops::expand_dims op; + auto result = op.evaluate({&x}, {}, {0}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(SingleDimTests, Test_ExpandDims_2) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3, 1}, {1, 2, 3}); + + sd::ops::expand_dims op; + auto result = op.evaluate({&x}, {}, {1}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + + +TEST_F(SingleDimTests, Test_Squeeze_1) { + std::vector vecS({1}); + std::vector vecB({3.0f}); + auto x = NDArrayFactory::create('c', vecS, vecB); + auto exp = NDArrayFactory::create(3.0f); + + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto 
z = result.at(0); + + ASSERT_EQ(exp.rankOf(), z->rankOf()); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(SingleDimTests, Test_Squeeze_2) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); + + sd::ops::squeeze op; + auto result = op.evaluate({&x}, {}, {}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} + +TEST_F(SingleDimTests, Test_Permute_1) { + auto x = NDArrayFactory::create('c', {3}, {1, 2, 3}); + auto exp = NDArrayFactory::create('c', {3}, {1, 2, 3}); + + sd::ops::permute op; + auto result = op.evaluate({&x}, {}, {0}); + ASSERT_EQ(ND4J_STATUS_OK, result.status()); + + auto z = result.at(0); + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + + +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SortCpuTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SortCpuTests.cpp new file mode 100644 index 000000000..3f21a7e79 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SortCpuTests.cpp @@ -0,0 +1,106 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class SortCpuTests : public testing::Test { +public: + +}; + + +TEST_F(SortCpuTests, test_linear_sort_by_key_1) { + if (!Environment::getInstance().isCPU()) + return; + + auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + + sortByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +TEST_F(SortCpuTests, test_linear_sort_by_val_1) { + if (!Environment::getInstance().isCPU()) + return; + + auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + + sortByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +TEST_F(SortCpuTests, test_tad_sort_by_key_1) { + if (!Environment::getInstance().isCPU()) + return; + + auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 
2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + + int axis = 1; + sortTadByKey(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +TEST_F(SortCpuTests, test_tad_sort_by_val_1) { + if (!Environment::getInstance().isCPU()) + return; + + auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + + int axis = 1; + sortTadByValue(nullptr, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SortCudaTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SortCudaTests.cu new file mode 100644 index 000000000..7525e2597 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SortCudaTests.cu @@ -0,0 +1,126 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache 
License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class SortCudaTests : public testing::Test { +public: + +}; + + +TEST_F(SortCudaTests, test_linear_sort_by_key_1) { + auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + sortByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +TEST_F(SortCudaTests, test_linear_sort_by_val_1) { + auto k = NDArrayFactory::create('c', {10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {10}, 
{0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +TEST_F(SortCudaTests, test_linear_sort_by_val_2) { + auto k = NDArrayFactory::create('c', {6}, {0, 1, 2, 3, 4, 5}); +// auto v = NDArrayFactory::create('c', {6}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + NDArray v = NDArrayFactory::create('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); + auto ek = NDArrayFactory::create('c', {6}, {3, 0, 1, 2, 4, 5}); + auto ev = NDArrayFactory::create('c', {6}, {0.95, 0.9, 0.75, 0.6, 0.5, 0.3}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + sortByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), true); + k.tickWriteDevice(); + v.tickWriteDevice(); + // k.printIndexedBuffer("KEYS"); + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +TEST_F(SortCudaTests, test_tad_sort_by_key_1) { + auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + int axis = 1; + sortTadByKey(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), 
k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + // k.printIndexedBuffer("k"); + // v.printIndexedBuffer("v"); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} + +TEST_F(SortCudaTests, test_tad_sort_by_val_1) { + auto k = NDArrayFactory::create('c', {2, 10}, {1, 3, 5, 9, 0, 2, 4, 6, 7, 8, 1, 3, 5, 9, 0, 2, 4, 6, 7, 8}); + auto v = NDArrayFactory::create('c', {2, 10}, {1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5, 1.5, 3.5, 5.5, 9.5, 0.5, 2.5, 4.5, 6.5, 7.5, 8.5}); + + auto ek = NDArrayFactory::create('c', {2, 10}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + auto ev = NDArrayFactory::create('c', {2, 10}, {0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5, 0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5}); + + Nd4jPointer extras[2] = {nullptr, LaunchContext::defaultContext()->getCudaStream()}; + + int axis = 1; + sortTadByValue(extras, k.buffer(), k.shapeInfo(), k.specialBuffer(), k.specialShapeInfo(), v.buffer(), v.shapeInfo(), v.specialBuffer(), v.specialShapeInfo(), &axis, 1, false); + k.tickWriteDevice(); + v.tickWriteDevice(); + + ASSERT_EQ(ek, k); + ASSERT_EQ(ev, v); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SparseUtilsTest.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SparseUtilsTest.cpp new file mode 100644 index 000000000..8afef3701 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SparseUtilsTest.cpp @@ -0,0 +1,248 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 04.08.17. +// + +#include "testlayers.h" +#include +#include +#include "ops/specials_sparse.h" +using namespace sd; + +////////////////////////////////////////////////////////////////////// +class SparseUtilsTest : public testing::Test { +public: + static const Nd4jLong nnz = 40; + static const int rank = 3; +}; + + +////////////////////////////////////////////////////////////////////// +TEST_F(SparseUtilsTest, SortCOOindices_Test) { + +#ifndef __CUDABLAS__ + + + Nd4jLong * indicesArr = new Nd4jLong[nnz * rank]{ + 0,2,7, + 2,36,35, + 3,30,17, + 5,12,22, + 5,43,45, + 6,32,11, + 8,8,32, + 9,29,11, + 5,11,22, + 15,26,16, + 17,48,49, + 24,28,31, + 26,6,23, + 31,21,31, + 35,46,45, + 37,13,14, + 6,38,18, + 7,28,20, + 8,29,39, + 8,32,30, + 9,42,43, + 11,15,18, + 13,18,45, + 29,26,39, + 30,8,25, + 42,31,24, + 28,33,5, + 31,27,1, + 35,43,26, + 36,8,37, + 39,22,14, + 39,24,42, + 42,48,2, + 43,26,48, + 44,23,49, + 45,18,34, + 46,28,5, + 46,32,17, + 48,34,44, + 49,38,39, + }; + + Nd4jLong * expIndicesArr = new Nd4jLong[nnz * rank]{ + 0, 2, 7, + 2, 36, 35, + 3, 30, 17, + 5, 11, 22, + 5, 12, 22, + 5, 43, 45, + 6, 32, 11, + 6, 38, 18, + 7, 28, 20, + 8, 8, 32, + 8, 29, 39, + 8, 32, 30, + 9, 29, 11, + 9, 42, 43, + 11, 15, 18, + 13, 18, 45, + 15, 26, 16, + 17, 48, 49, + 24, 28, 31, + 26, 6, 23, + 28, 33, 5, + 29, 26, 39, + 30, 8, 25, + 31, 21, 31, + 31, 27, 1, + 35, 43, 26, + 35, 46, 45, + 36, 8, 37, + 37, 13, 14, + 39, 22, 14, + 39, 24, 42, + 42, 31, 24, + 42, 48, 2, + 43, 26, 48, + 44, 23, 49, + 45, 18, 
34, + 46, 28, 5, + 46, 32, 17, + 48, 34, 44, + 49, 38, 39, + }; + + auto values = NDArrayFactory::create('c', {40}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, + 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39}); + + auto expValues = NDArrayFactory::create('c', {40}, {0, 1, 2, 8, 3, 4, 5, 16, 17, 6, 18, 19, 7, 20, 21, 22, 9, + 10, 11, 12, 26, 23, 24, 13, 27, 28, 14, 29, 15, 30, 31, 25, 32, 33, + 34, 35, 36, 37, 38, 39 + }); + + sd::sparse::SparseUtils::sortCooIndicesGeneric(indicesArr, reinterpret_cast(values.buffer()), nnz, rank); + + for ( int i = 0; i < rank * nnz; ++i){ + ASSERT_EQ(expIndicesArr[i], indicesArr[i]); + } + + ASSERT_TRUE(expValues.equalsTo(values)); + + + delete[] indicesArr; + delete[] expIndicesArr; + + +#endif +} + +////////////////////////////////////////////////////////////////////// +TEST_F(SparseUtilsTest, RavelIndices_Test) { + +#ifndef __CUDABLAS__ + + Nd4jLong * indicesArrExp = new Nd4jLong[nnz * rank]{ + 0,2,7, + 2,36,35, + 3,30,17, + 5,12,22, + 5,43,45, + 6,32,11, + 8,8,32, + 9,29,11, + 5,11,22, + 15,26,16, + 17,48,49, + 24,28,31, + 26,6,23, + 31,21,31, + 35,46,45, + 37,13,14, + 6,38,18, + 7,28,20, + 8,29,39, + 8,32,30, + 9,42,43, + 11,15,18, + 13,18,45, + 29,26,39, + 30,8,25, + 42,31,24, + 28,33,5, + 31,27,1, + 35,43,26, + 36,8,37, + 39,22,14, + 39,24,42, + 42,48,2, + 43,26,48, + 44,23,49, + 45,18,34, + 46,28,5, + 46,32,17, + 48,34,44, + 49,38,39, + }; + Nd4jLong * indicesArr = new Nd4jLong[nnz * rank]; + + Nd4jLong * flatIndicesExp = new Nd4jLong[nnz]{ + 147, 10955, 14717, 21862, 24055, 27451, 34192, 39841, + 21792, 64836, 74809, 102791, 109643, 131701, 150265, 156324, + 27878, 31380, 35669, 35870, 40783, 47268, 55905, 123659, + 126585, 178594, 119915, 132091, 150036, 151797, 165354, 165522, + 179762, 182468, 186459, 190294, 195165, 195457, 204024, 208499 + }; + + Nd4jLong * flatIndices = new Nd4jLong[nnz]; + + + Nd4jLong * shape = new Nd4jLong[rank]{50, 60, 70}; + Nd4jLong * 
shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape); + + + sd::sparse::IndexUtils::ravelMultiIndex(indicesArrExp, flatIndices, nnz, shapeInfoBuffer, ND4J_CLIPMODE_THROW); + + for ( int i = 0; i < nnz; ++i){ + ASSERT_EQ(flatIndicesExp[i], flatIndices[i]); + } + + sd::sparse::IndexUtils::unravelIndex(indicesArr, flatIndices, nnz, shapeInfoBuffer); + + for ( int i = 0; i < nnz * rank; ++i){ + ASSERT_EQ(indicesArrExp[i], indicesArr[i]); + } + + shape[2] = 30; + delete[] shapeInfoBuffer; + shapeInfoBuffer = shape::shapeBuffer(rank, sd::DataType::INT64, shape); + + try { + sd::sparse::IndexUtils::ravelMultiIndex(indicesArrExp, flatIndices, nnz, shapeInfoBuffer, ND4J_CLIPMODE_THROW); + FAIL(); + } catch (const std::runtime_error& e) { + // pass + } + + delete[] indicesArrExp; + delete[] indicesArr; + delete[] flatIndicesExp; + delete[] flatIndices; + delete[] shape; + delete[] shapeInfoBuffer; + +#endif +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/StashTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/StashTests.cpp new file mode 100644 index 000000000..fd479bd4e --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/StashTests.cpp @@ -0,0 +1,90 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_STASHTESTS_H +#define LIBND4J_STASHTESTS_H + +#include +#include "testlayers.h" +#include + +using namespace sd; +using namespace sd::graph; + +class StashTests : public testing::Test { +public: + +}; + +TEST_F(StashTests, BasicTests_1) { + Stash stash; + + auto alpha = NDArrayFactory::create_('c',{5, 5}); + alpha->assign(1.0); + + auto beta = NDArrayFactory::create_('c',{5, 5}); + beta->assign(2.0); + + auto cappa = NDArrayFactory::create_('c',{5, 5}); + cappa->assign(3.0); + + stash.storeArray(1, "alpha", alpha); + stash.storeArray(2, "alpha", beta); + stash.storeArray(3, "cappa", cappa); + + ASSERT_TRUE(stash.checkStash(1, "alpha")); + ASSERT_TRUE(stash.checkStash(2, "alpha")); + ASSERT_TRUE(stash.checkStash(3, "cappa")); + + ASSERT_FALSE(stash.checkStash(3, "alpha")); + ASSERT_FALSE(stash.checkStash(2, "beta")); + ASSERT_FALSE(stash.checkStash(1, "cappa")); +} + + +TEST_F(StashTests, BasicTests_2) { + Stash stash; + + auto alpha = NDArrayFactory::create_('c',{5, 5}); + alpha->assign(1.0); + + auto beta = NDArrayFactory::create_('c',{5, 5}); + beta->assign(2.0); + + auto cappa = NDArrayFactory::create_('c',{5, 5}); + cappa->assign(3.0); + + stash.storeArray(1, "alpha", alpha); + stash.storeArray(1, "beta", beta); + stash.storeArray(1, "cappa", cappa); + + ASSERT_FALSE(stash.checkStash(2, "alpha")); + ASSERT_FALSE(stash.checkStash(2, "beta")); + ASSERT_FALSE(stash.checkStash(2, "cappa")); + + ASSERT_TRUE(alpha == stash.extractArray(1, "alpha")); + ASSERT_TRUE(beta == stash.extractArray(1, "beta")); + ASSERT_TRUE(cappa == stash.extractArray(1, "cappa")); + +} + +#endif //LIBND4J_STASHTESTS_H diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/StringTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/StringTests.cpp new file mode 100644 
index 000000000..156831456 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/StringTests.cpp @@ -0,0 +1,880 @@ +/* + * ****************************************************************************** + * * + * * + * * This program and the accompanying materials are made available under the + * * terms of the Apache License, Version 2.0 which is available at + * * https://www.apache.org/licenses/LICENSE-2.0. + * * + * * See the NOTICE file distributed with this work for additional + * * information regarding copyright ownership. + * * Unless required by applicable law or agreed to in writing, software + * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * * License for the specific language governing permissions and limitations + * * under the License. + * * + * * SPDX-License-Identifier: Apache-2.0 + * ***************************************************************************** + */ + +// +// @author raver119@gmail.com +// @author Oleg Semeniv +// + + +#include +#include +#include "testlayers.h" +#include +#include +#include + +using namespace sd; + +class StringTests : public testing::Test { +public: + +}; +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_1) { + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + + ASSERT_EQ(f, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_2) { + std::string f("alpha"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + + ASSERT_EQ(f, z); +} 
+///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_3) { + + auto array = NDArrayFactory::string({3, 2}, {"alpha", "beta", "gamma", "phi", "theta", "omega"}); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_4) { + + NDArray array( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_5) { + + NDArray array( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_6) { + + NDArray array( { 3, 2 }, std::vector{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_7) { + + NDArray array( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma€한", U"pÿqwe", U"ß水𝄋", U"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_8) { + + NDArray array( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma€한", u"pÿqwe", u"ß水𝄋", u"omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} 
+///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_9) { + + NDArray array( { 3, 2 }, std::vector{ "alpha", "beta", "gamma€한", "pÿqwe", "ß水𝄋", "omega" }); + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_10) { + + NDArray array(std::u32string(U"gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_11) { + + NDArray array(U"gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_12) { + + NDArray array(std::u16string(u"gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_13) { + + NDArray array(u"gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_14) { + + NDArray array(std::string("gamma€한")); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_15) { + + NDArray array("gamma€한"); + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// 
+TEST_F(StringTests, Basic_Test_16) { + + auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ "alpha", "beta", "gamma", "phi", "theta", "omega" }); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_17) { + + auto array = NDArrayFactory::string({ 3, 2 }, std::vector{ "alpha", "beta", "gamma", "phi", "theta", "omega" }); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_18) { + + auto array = NDArrayFactory::string({ 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" }); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_19) { + + auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ u"alpha", u"beta", u"gamma", u"phi", u"theta", u"omega" }); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_20) { + + auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ U"alpha", U"beta", U"gamma", U"phi", U"theta", U"omega" }); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_21) { + + auto array = NDArrayFactory::string( { 3, 2 }, std::vector{ U"alpha", U"òèçùà12345¤z", U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї", U"phi", U"theta", U"omega" }); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, 
array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_22) { + std::u16string f(u"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + + ASSERT_EQ(f, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_23) { + std::u32string f(U"ß水𝄋ÿ€한𐍈®кею90ощъ]ї"); + auto array = NDArrayFactory::string(f.c_str()); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + + ASSERT_EQ(f, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_1) { + auto array = NDArrayFactory::string( {3}, {"alpha", "beta", "gamma"}); + auto vector = array.asByteVector(); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_dup_1) { + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto dup = new NDArray(array.dup()); + + auto z0 = array.e(0); + auto z1 = dup->e(0); + + ASSERT_EQ(f, z0); + ASSERT_EQ(f, z1); + + delete dup; +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, byte_length_test_1) { + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + + ASSERT_EQ(f.length(), StringUtils::byteLength(array)); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, byte_length_test_2) { + auto array = NDArrayFactory::string( {2}, {"alpha", "beta"}); + + ASSERT_EQ(9, StringUtils::byteLength(array)); +} 
+///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_split_1) { + auto split = StringUtils::split("alpha beta gamma", " "); + + ASSERT_EQ(3, split.size()); + ASSERT_EQ(std::string("alpha"), split[0]); + ASSERT_EQ(std::string("beta"), split[1]); + ASSERT_EQ(std::string("gamma"), split[2]); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf8_utf16) { + + std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + + std::u16string utf16Res; + ASSERT_TRUE(StringUtils::u8StringToU16String(utf8, utf16Res)); + + ASSERT_EQ(utf16Res.size(), utf16Exp.size()); + for (auto i = 0; i < utf16Exp.size(); i++) { + ASSERT_EQ(utf16Exp[i], utf16Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf8_utf32) { + + std::string utf8 = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + + std::u32string utf32Res; + ASSERT_TRUE(StringUtils::u8StringToU32String(utf8, utf32Res)); + + ASSERT_EQ(utf32Res.size(), utf32Exp.size()); + for (auto i = 0; i < utf32Exp.size(); i++) { + ASSERT_EQ(utf32Exp[i], utf32Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf16_utf8) { + + std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + + std::string utf8Res; + ASSERT_TRUE(StringUtils::u16StringToU8String(utf16, utf8Res)); + + ASSERT_EQ(utf8Res.size(), utf8Exp.size()); + for (auto i = 0; i < utf8Exp.size(); i++) { + ASSERT_EQ(utf8Exp[i], utf8Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf32_utf8) { 
+ + std::string utf8Exp = u8"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею 90ощъ]їїщkk1q\n\t\rop~"; + + std::string utf8Res; + ASSERT_TRUE(StringUtils::u32StringToU8String(utf32, utf8Res)); + + ASSERT_EQ(utf8Res.size(), utf8Exp.size()); + for (auto i = 0; i < utf8Exp.size(); i++) { + ASSERT_EQ(utf8Exp[i], utf8Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf16_utf32) { + + std::u32string utf32Exp = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u16string utf16 = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + + std::u32string utf32Res; + ASSERT_TRUE(StringUtils::u16StringToU32String(utf16, utf32Res)); + + ASSERT_EQ(utf32Res.size(), utf32Exp.size()); + for (auto i = 0; i < utf32Exp.size(); i++) { + ASSERT_EQ(utf32Exp[i], utf32Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, test_unicode_utf32_utf16) { + + std::u16string utf16Exp = u"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + std::u32string utf32 = U"\nòèçùà12345¤zß水𝄋ÿ€한𐍈®кею90ощъ]їїщkk1q\n\t\rop~"; + + std::u16string utf16Res; + ASSERT_TRUE(StringUtils::u32StringToU16String(utf32, utf16Res)); + + ASSERT_EQ(utf16Res.size(), utf16Exp.size()); + for (auto i = 0; i < utf16Exp.size(); i++) { + ASSERT_EQ(utf16Exp[i], utf16Res[i]); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, byte_length_test_Default) { + + std::string f("alpha"); + auto array = NDArrayFactory::string(f); + + ASSERT_EQ(f.length(), StringUtils::byteLength(array)); + + std::u16string f16(u"alpha"); + auto array16 = NDArrayFactory::string(f16); + + ASSERT_EQ(sizeof(char16_t)*f16.length(), StringUtils::byteLength(array16)); + + std::u32string f32(U"alpha"); + auto array32 = NDArrayFactory::string(f32); + + ASSERT_EQ(sizeof(char32_t) * f32.length(), 
StringUtils::byteLength(array32)); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, byte_length_test_UTF16) { + std::string f(u8"alpha"); + auto array = NDArrayFactory::string(f, sd::DataType::UTF16); + + ASSERT_EQ(sizeof(char16_t) * f.length(), StringUtils::byteLength(array)); + + std::u16string f16(u"alpha"); + auto array16 = NDArrayFactory::string(f16, sd::DataType::UTF16); + + ASSERT_EQ(sizeof(char16_t) * f16.length(), StringUtils::byteLength(array16)); + + std::u32string f32(U"alpha"); + auto array32 = NDArrayFactory::string(f32, sd::DataType::UTF16); + + ASSERT_EQ(sizeof(char16_t) * f32.length(), StringUtils::byteLength(array32)); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF16toU8) { + + std::u16string f16(u"alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, sd::DataType::UTF8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + + std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(f, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF32toU8) { + std::u32string f32(U"alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32.c_str(), sd::DataType::UTF8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + std::string f(u8"alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(f, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF16toU16) { + + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, sd::DataType::UTF16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); + + ASSERT_EQ(z, f16); +} 
+///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF32toU16) { + + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32, sd::DataType::UTF16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(z, f16); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF16toU32) { + + std::u16string f16(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f16, sd::DataType::UTF32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto z = array.e(0); + std::u32string fres(U"€alpha水𝄋ÿ€한𐍈®кею"); + ASSERT_EQ(z, fres); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF32toU32) { + + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + auto z = array.e(0); + ASSERT_EQ(f32, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_UTF8toU32) { + + std::string f(u8"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f, sd::DataType::UTF32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + std::u32string f32(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto z = array.e(0); + ASSERT_EQ(f32, z); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_StringVecU8toUTF16) { + auto array = NDArrayFactory::string({ 3, 2 }, { "alpha€", "beta", "gamma水", "phi", "theta", "omega水" }, sd::DataType::UTF16); + + ASSERT_EQ(6, array.lengthOf()); + 
ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_StringVecU8toUTF32) { + auto array = NDArrayFactory::string( { 3, 2 }, { "alpha€", "beta水", "gamma", "phi", "theta", "omega" }, sd::DataType::UTF32); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_U8toUTF16) { + auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, sd::DataType::UTF16); + + auto vector = array.asByteVector(); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_U8toUTF32) { + auto array = NDArrayFactory::string({ 3 }, { "alpha", "beta", "gamma" }, sd::DataType::UTF32); + + auto vector = array.asByteVector(); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_StringVecU16toUTF16) { + auto array = NDArrayFactory::string({ 3, 2 }, { u"alpha水", u"beta", u"gamma", u"phi", u"theta水", u"omega" }, sd::DataType::UTF16); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_StringVecU16toUTF32) { + auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha水", u"beta", u"gamma水", u"phi", u"theta", u"omega" }, sd::DataType::UTF32); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_StringVecU16toUTF8) { + auto array = NDArrayFactory::string( { 3, 2 }, { u"alpha€", u"beta水", u"gamma", u"phi水", u"theta", u"omega" }, sd::DataType::UTF8); 
+ + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_U16toUTF8) { + auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, sd::DataType::UTF8); + + auto vector = array.asByteVector(); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_U16toUTF16) { + auto array = NDArrayFactory::string( { 3 }, { u"alpha", u"beta", u"gamma" }, sd::DataType::UTF16); + + auto vector = array.asByteVector(); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_U16toUTF32) { + auto array = NDArrayFactory::string( { 3 }, { u"alpha水", u"beta", u"gamma水" }, sd::DataType::UTF32); + + auto vector = array.asByteVector(); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_StringVecU32toUTF32) { + auto array = NDArrayFactory::string( { 3, 2 }, { U"alpha€", U"beta水", U"gamma", U"phi", U"theta", U"omega水" }, sd::DataType::UTF32); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_StringVecU32toUTF16) { + auto array = NDArrayFactory::string({ 3, 2 }, { U"alpha水", U"水beta", U"gamma", U"phi水", U"theta", U"omega" }, sd::DataType::UTF16); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); + + printf("Array elements size: \n"); + for (int e = 0; e < array.lengthOf(); e++) { + printf("Element %d size: %d\n", e, static_cast(array.e(e).size())); + } +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_Test_StringVecU32toUTF8) { + auto array = 
NDArrayFactory::string( { 3, 2 }, { U"alpha水", U"beta", U"gamma水", U"phi", U"theta", U"omega" }, sd::DataType::UTF8); + + ASSERT_EQ(6, array.lengthOf()); + ASSERT_EQ(2, array.rankOf()); + + array.printIndexedBuffer("String array"); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_U32toUTF32) { + auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma" }, sd::DataType::UTF32); + + auto vector = array.asByteVector(); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_U32toUTF16) { + auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta水", U"gamma水" }, sd::DataType::UTF16); + + auto vector = array.asByteVector(); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Export_Test_U32toUTF8) { + auto array = NDArrayFactory::string( { 3 }, { U"alpha", U"beta", U"gamma水" }, sd::DataType::UTF8); + + auto vector = array.asByteVector(); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_dup_UTF16) { + std::u16string f(u"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto dup = new NDArray(array.dup()); + + auto z0 = array.e(0); + auto z1 = dup->e(0); + + ASSERT_EQ(f, z0); + ASSERT_EQ(f, z1); + + delete dup; +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_dup_UTF32) { + std::u32string f(U"€alpha水𝄋ÿ€한𐍈®кею"); + auto array = NDArrayFactory::string(f); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto dup = new NDArray(array.dup()); + + auto z0 = array.e(0); + auto z1 = dup->e(0); + + ASSERT_EQ(f, z0); + ASSERT_EQ(f, z1); + + delete dup; +} 
+///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_cast_UTF32toUTF8) { + + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto aCast = array.cast(sd::DataType::UTF8); + + auto z0 = array.e(0); + auto z1 = aCast.e(0); + + ASSERT_EQ(u32, z0); + ASSERT_EQ(u8, z1); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_cast_UTF32toUTF16) { + + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto aCast = array.cast(sd::DataType::UTF16); + + auto z0 = array.e(0); + auto z1 = aCast.e(0); + + ASSERT_EQ(u32, z0); + ASSERT_EQ(u16, z1); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_cast_UTF32toUTF32) { + + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u32); + ASSERT_EQ(sd::DataType::UTF32, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto aCast = array.cast(sd::DataType::UTF32); + + auto z0 = array.e(0); + auto z1 = aCast.e(0); + + ASSERT_EQ(u32, z0); + ASSERT_EQ(u32, z1); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_cast_UTF16toUTF16) { + + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto aCast = array.cast(sd::DataType::UTF16); + + auto z0 = array.e(0); + auto z1 = aCast.e(0); + + ASSERT_EQ(u16, z0); + 
ASSERT_EQ(u16, z1); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_cast_UTF16toUTF32) { + + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto aCast = array.cast(sd::DataType::UTF32); + + auto z0 = array.e(0); + auto z1 = aCast.e(0); + + ASSERT_EQ(u32, z1); + ASSERT_EQ(u16, z0); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_cast_UTF16toUTF8) { + + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); + + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u16); + ASSERT_EQ(sd::DataType::UTF16, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto aCast = array.cast(sd::DataType::UTF8); + + auto z0 = array.e(0); + auto z1 = aCast.e(0); + + ASSERT_EQ(u8, z1); + ASSERT_EQ(u16, z0); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_cast_UTF8toUTF8) { + + std::string u8("€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto aCast = array.cast(sd::DataType::UTF8); + + auto z0 = array.e(0); + auto z1 = aCast.e(0); + + ASSERT_EQ(u8, z1); + ASSERT_EQ(u8, z0); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_cast_UTF8toUTF16) { + + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); + + std::u16string u16(u"€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto aCast = array.cast(sd::DataType::UTF16); + + auto z0 = array.e(0); + auto 
z1 = aCast.e(0); + + ASSERT_EQ(u8, z0); + ASSERT_EQ(u16, z1); +} +///////////////////////////////////////////////////////////////////////// +TEST_F(StringTests, Basic_cast_UTF8toUTF32) { + + std::string u8(u8"€alpha水𝄋ÿ€한𐍈®кею"); + + std::u32string u32(U"€alpha水𝄋ÿ€한𐍈®кею"); + + auto array = NDArrayFactory::string(u8); + ASSERT_EQ(sd::DataType::UTF8, array.dataType()); + + ASSERT_EQ(1, array.lengthOf()); + ASSERT_EQ(0, array.rankOf()); + + auto aCast = array.cast(sd::DataType::UTF32); + + auto z0 = array.e(0); + auto z1 = aCast.e(0); + + ASSERT_EQ(u8, z0); + ASSERT_EQ(u32, z1); +} + +TEST_F(StringTests, test_bit_string_1) { + // check bits -> vector conversion first + auto vec = BitwiseUtils::valueBits(1); + + // check bits -> string conversion next; + auto str = StringUtils::bitsToString(1); + ASSERT_EQ(32, str.length()); + ASSERT_EQ(std::string("00000000000000000000000000000001"), str); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SwitchTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SwitchTests.cpp new file mode 100644 index 000000000..41881e323 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/SwitchTests.cpp @@ -0,0 +1,253 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 13.10.2017. +// + +#include "testlayers.h" +#include + +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class SwitchTests : public testing::Test { +public: + +}; + +TEST_F(SwitchTests, SwitchTest1) { + Graph graph; + + FlowPath flowPath; + + auto variableSpace = graph.getVariableSpace(); + variableSpace->setFlowPath(&flowPath); + + auto input = NDArrayFactory::create_('c',{32, 100}); + input->assign(-119.0f); + + auto condtionX = NDArrayFactory::create_('c', {1, 1}); + condtionX->p(0, 0.0f); + auto condtionY = NDArrayFactory::create_('c', {1, 1}); + condtionY->p(0, 0.0f); + + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, condtionX); + variableSpace->putVariable(-3, condtionY); + + // this is just 2 ops, that are executed sequentially. We don't really care bout them + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); + + // this is our condition op, we'll be using Equals condition, on variables conditionX and conditionY (ids -2 and -3 respectively) + // we're creating this op manually in tests, as always. + sd::ops::eq_scalar eqOp; + auto nodeCondition = new Node(&eqOp, 119, {-2, -3}); + //nodeCondition->setOpType(OpType_BOOLEAN); + + // now, this is Switch operation. It takes BooleanOperation operation in, + // and based on evaluation result (true/false) - it'll pass data via :0 or :1 output + // other idx will be considered disabled, and that graph branch won't be executed + sd::ops::Switch switchOp; + auto nodeSwitch = new Node(&switchOp, 3, {2, 119}, {4, 5}); + + // these 2 ops are connected to FALSE and TRUE outputs. 
output :0 considered FALSE, and output :1 considered TRUE + auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 4, {}, {}); + nodeZ0->pickInput(3, 0); + auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 5, {}, {}); + nodeZ1->pickInput(3, 1); + + + graph.addNode(nodeA); + graph.addNode(nodeB); + graph.addNode(nodeCondition); + graph.addNode(nodeSwitch); + graph.addNode(nodeZ0); + graph.addNode(nodeZ1); + + graph.buildGraph(); + + // we're making sure nodes connected to the Switch have no other inputs in this Graph + ASSERT_EQ(1, nodeZ0->input()->size()); + ASSERT_EQ(1, nodeZ1->input()->size()); + + // just validating topo sort + ASSERT_EQ(0, nodeA->getLayer()); + ASSERT_EQ(0, nodeCondition->getLayer()); + ASSERT_EQ(1, nodeB->getLayer()); + ASSERT_EQ(2, nodeSwitch->getLayer()); + ASSERT_EQ(3, nodeZ0->getLayer()); + ASSERT_EQ(3, nodeZ1->getLayer()); + + // executing graph + Nd4jStatus status = GraphExecutioner::execute(&graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + // nd4j_printf("Z0: [%i]; Z1: [%i]\n", flowPath.isNodeActive(nodeZ0->id()), flowPath.isNodeActive(nodeZ1->id())); + + // we know that Switch got TRUE evaluation, so :0 should be inactive + ASSERT_FALSE(flowPath.isNodeActive(nodeZ0->id())); + + // and :1 should be active + ASSERT_TRUE(flowPath.isNodeActive(nodeZ1->id())); + + std::pair unexpected(4,0); + std::pair expectedResultIndex(5,0); + ASSERT_TRUE(variableSpace->hasVariable(expectedResultIndex)); + + // getting output of nodeZ1 + auto output = variableSpace->getVariable(expectedResultIndex)->getNDArray(); + + // and veryfing it against known expected value + ASSERT_NEAR(-118.0f, output->e(0), 1e-5f); +} + +TEST_F(SwitchTests, SwitchTest2) { + Graph graph; + + FlowPath flowPath; + auto variableSpace = graph.getVariableSpace(); + variableSpace->setFlowPath(&flowPath); + + auto input = NDArrayFactory::create_('c',{32, 100}); + input->assign(-119.0f); + + auto condtionX = NDArrayFactory::create_('c', {1, 1}); + 
condtionX->p(0, 1.0f); + auto condtionY = NDArrayFactory::create_('c', {1, 1}); + condtionY->p(0, 1.0f); + + + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, condtionX); + variableSpace->putVariable(-3, condtionY); + + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); + + auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3); + scopeCondition->setName("scopeCondition"); + + auto nodeCondition = new Node(OpType_LOGIC, logic::Scope, 119, {-2, -3}); + nodeCondition->setScopeInfo(3, "scopeCondition"); + + sd::ops::eq_scalar eqOp; + nodeCondition->setCustomOp(&eqOp); + + auto nodeSwitch = new Node(OpType_LOGIC, logic::Switch, 5, {3, 2}); + + sd::ops::Switch switchOp; + nodeSwitch->setCustomOp(&switchOp); + + + // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE + auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Abs, 6, {}, {}); + nodeZ0->pickInput(5, 0); + auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 7, {}, {}); + nodeZ1->pickInput(5, 1); + + graph.addNode(nodeA); + graph.addNode(nodeB); + graph.addNode(scopeCondition); + graph.addNode(nodeCondition); + graph.addNode(nodeSwitch); + graph.addNode(nodeZ0); + graph.addNode(nodeZ1); + + Nd4jStatus status = GraphExecutioner::execute(&graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(!flowPath.isNodeActive(nodeZ0->id())); + ASSERT_TRUE(flowPath.isNodeActive(nodeZ1->id())); + + auto z = graph.getVariableSpace()->getVariable(7)->getNDArray(); + + // abs(-119) = 119; 1 - 119 = -118 + ASSERT_NEAR(-118.f, z->e(0), 1e-5); +} + +TEST_F(SwitchTests, SwitchTest3) { + Graph graph; + + FlowPath flowPath; + auto variableSpace = graph.getVariableSpace(); + variableSpace->setFlowPath(&flowPath); + + auto input = NDArrayFactory::create_('c',{32, 100}); + input->assign(-119.0f); + + auto condtionX = 
NDArrayFactory::create_('c', {1, 1}); + condtionX->p(0, 2.0f); + auto condtionY = NDArrayFactory::create_('c', {1, 1}); + condtionY->p(0, 1.0f); + + + variableSpace->putVariable(-1, input); + variableSpace->putVariable(-2, condtionX); + variableSpace->putVariable(-3, condtionY); + + + auto nodeA = new Node(OpType_TRANSFORM_SAME, transform::Abs, 1, {-1}, {2}); + auto nodeB = new Node(OpType_TRANSFORM_SAME, transform::Abs, 2, {1}, {3}); + + auto scopeCondition = new Node(OpType_LOGIC, logic::Scope, 3); + scopeCondition->setName("scopeCondition"); + + auto nodeCondition = new Node(OpType_LOGIC, logic::Scope, 119, {-2, -3}); + nodeCondition->setScopeInfo(3, "scopeCondition"); + + sd::ops::eq_scalar eqOp; + nodeCondition->setCustomOp(&eqOp); + + auto nodeSwitch = new Node(OpType_LOGIC, logic::Switch, 5, {3, 2}); + + sd::ops::Switch switchOp; + nodeSwitch->setCustomOp(&switchOp); + + + // these 2 ops are connected to FALSE and TRUE outputs. output :0 considered FALSE, and output :1 considered TRUE + auto nodeZ0 = new Node(OpType_TRANSFORM_SAME, transform::Neg, 6, {}, {}); + nodeZ0->pickInput(5, 0); + auto nodeZ1 = new Node(OpType_TRANSFORM_SAME, transform::OneMinus, 7, {}, {}); + nodeZ1->pickInput(5, 1); + + graph.addNode(nodeA); + graph.addNode(nodeB); + graph.addNode(scopeCondition); + graph.addNode(nodeCondition); + graph.addNode(nodeSwitch); + graph.addNode(nodeZ0); + graph.addNode(nodeZ1); + + Nd4jStatus status = GraphExecutioner::execute(&graph); + + ASSERT_EQ(ND4J_STATUS_OK, status); + + ASSERT_TRUE(flowPath.isNodeActive(nodeZ0->id())); + ASSERT_TRUE(!flowPath.isNodeActive(nodeZ1->id())); + + auto z = graph.getVariableSpace()->getVariable(6)->getNDArray(); + + // abs(-119) = 119; Neg(119) = 119 + ASSERT_NEAR(-119.f, z->e(0), 1e-5); +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/TadTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/TadTests.cpp new file mode 100644 index 000000000..e7a0538a4 --- /dev/null +++ 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/TadTests.cpp @@ -0,0 +1,445 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_TADTESTS_H +#define LIBND4J_TADTESTS_H + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; + +class TadTests : public testing::Test { +public: + int numLoops = 100000000; + + int extLoops = 1000; + int intLoops = 1000; +}; + +TEST_F(TadTests, Test4DTad1) { + + NDArray* arraySource = sd::NDArrayFactory::linspace(1.0f, 10000.0f, 10000); + + Nd4jLong badShape[] = {4, 2, 1, 4, 4, 80, 16, 4, 1, 8192, -1, 99}; + Nd4jLong goodShape[] = {4, 2, 1, 4, 4, 16, 16, 4, 1, 8192, 1, 99}; + + std::vector buff = arraySource->getBufferAsVector(); + + NDArray* arrayExp = new NDArray(buff.data(), goodShape); + NDArray* arrayBad = new NDArray(buff.data(), badShape); + + int dim = 1; + shape::TAD tad; + tad.init(arrayBad->shapeInfo(), &dim, 1); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); + + int exp[] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95 }; + for (int e = 0; e < 32; 
e++) + ASSERT_EQ((int) tad.tadOffsets[e], exp[e]); + + delete arrayExp; + delete arrayBad; + delete arraySource; +} + +TEST_F(TadTests, TestNumTads1) { + auto x = NDArrayFactory::create('c', {2, 3}); + auto y = NDArrayFactory::create('c', {2, 2}); + + std::vector dim({0}); + + Nd4jLong tadLengthX = shape::tadLength(x.shapeInfo(), dim.data(), dim.size()); + Nd4jLong numTadsX = x.lengthOf() / tadLengthX; + + Nd4jLong tadLengthY = shape::tadLength(y.shapeInfo(), dim.data(), dim.size()); + Nd4jLong numTadsY = y.lengthOf() / tadLengthY; + + ASSERT_EQ(2, tadLengthX); + ASSERT_EQ(3, numTadsX); + + ASSERT_EQ(2, tadLengthY); + ASSERT_EQ(2, numTadsY); +} + +TEST_F(TadTests, TestShapeTad_1) { + + float buff[] = {1,2,3,4,5,6,7,8,9,10,11,12,13,14,16,16,17,18,19,20,21,22,23,24}; + Nd4jLong shapeInfo[] = {3, 2, 3, 4, 12, 4, 1, 8192, 1, 99}; + + NDArray input(buff, shapeInfo); + + std::vector dimensions = {0,1,2}; + Nd4jLong tadLength = shape::tadLength(input.shapeInfo(), dimensions.data(), dimensions.size()); + Nd4jLong numTads = input.lengthOf() / tadLength; + + shape::TAD tad; + tad.init(input.shapeInfo(), dimensions.data(), dimensions.size()); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); + + auto tadShapeInfo = new Nd4jLong[shape::shapeInfoLength(tad.tadOnlyShapeInfo[0])]; + std::memcpy(tadShapeInfo, tad.tadOnlyShapeInfo, shape::shapeInfoByteLength(tad.tadOnlyShapeInfo)); + + float* tadBuff = reinterpret_cast(input.buffer()) + tad.tadOffsets[0]; + NDArray tadArr(tadBuff, tadShapeInfo); + + ASSERT_TRUE(numTads==1); + ASSERT_TRUE(input.isSameShapeStrict(tadArr)); + ASSERT_TRUE(input.equalsTo(&tadArr)); + + delete[] tadShapeInfo; +} + +TEST_F(TadTests, TadNoAxis_1) { + auto array = NDArrayFactory::create('c', {2, 3}); + + shape::TAD tad; + tad.init(array.shapeInfo(), nullptr, 0); + tad.createTadOnlyShapeInfo(); + tad.createOffsets(); + + ASSERT_TRUE(tad.wholeThing); + + ASSERT_TRUE(shape::equalsStrict(tad.tadOnlyShapeInfo, array.shapeInfo())); +} + +TEST_F(TadTests, 
TadEdgeCase_1) { + auto array = NDArrayFactory::create('c', {5, 4, 1}); + auto exp = NDArrayFactory::create('c', {5, 4}); + array.linspace(1); + + auto tad = array(0, {2}); + + ASSERT_TRUE(exp.isSameShape(tad)); +} + +TEST_F(TadTests, TestEdgeCase_2) { + + auto array = NDArrayFactory::create('f', {2, 3, 1}, {1, 4, 2, 5, 3, 6}); + + for (int e = 0 ; e < array.lengthOf(); e++) { + auto tad = array(e, {0,1}); + ASSERT_NEAR(tad.e(0), array.e(e), 1e-5); + } +} + +TEST_F(TadTests, TadEdgeCase_2) { + auto array = NDArrayFactory::create('c', {2, 3, 4}); + + auto tad = array(0, {0,2}); + + ASSERT_EQ(3, tad.lengthOf()); +} + + +TEST_F(TadTests, test_Tad_Ews_optimization_1) { + shape::TAD xTad; + + std::array array = {1,2}; + ASSERT_TRUE(xTad.dimensionsDescending(3, array.data(), array.size())); +} + +TEST_F(TadTests, test_Tad_Ews_optimization_2) { + shape::TAD xTad; + + std::array array = {0,2}; + ASSERT_FALSE(xTad.dimensionsDescending(3, array.data(), array.size())); +} + +TEST_F(TadTests, test_Tad_Ews_optimization_3) { + shape::TAD xTad; + + std::array array = {1}; + ASSERT_TRUE(xTad.dimensionsDescending(2, array.data(), array.size())); +} + +TEST_F(TadTests, test_Tad_Ews_optimization_4) { + shape::TAD xTad; + + std::array array = {0}; + ASSERT_TRUE(xTad.dimensionsDescending(1, array.data(), array.size())); +} + +TEST_F(TadTests, test_Tad_Ews_optimization_5) { + shape::TAD xTad; + + std::array array = {2,3}; + ASSERT_TRUE(xTad.dimensionsDescending(4, array.data(), array.size())); +} + +TEST_F(TadTests, test_TAD_empty_dims_1) { + Nd4jLong xShape[8] = {2, 150, 1, 3, 1, 16384, 3, 99}; + shape::TAD xTad; + xTad.init(xShape, reinterpret_cast(112L), 0); + xTad.createTadOnlyShapeInfo(); + xTad.createOffsets(); +} + +TEST_F(TadTests, test_tad_order_1) { + Nd4jLong xShape[8] = {2, 150, 10, 10, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 1, 10, 1, 1, 8192, 1, 99}; + shape::TAD xTad; + int dim = 1; + xTad.init(xShape, &dim, 1); + xTad.createTadOnlyShapeInfo(); + + 
ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); +} + +TEST_F(TadTests, test_tad_order_2) { + Nd4jLong xShape[8] = {2, 150, 10, 10, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 1, 150, 1, 10, 8192, 10, 99}; + shape::TAD xTad; + int dim = 0; + xTad.init(xShape, &dim, 1); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); +} + + +TEST_F(TadTests, test_tad_order_3) { + Nd4jLong xShape[10] = {3, 10, 20, 30, 600 ,30, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 1, 30, 1, 1, 8192, 1, 99}; + shape::TAD xTad; + int dim = 2; + xTad.init(xShape, &dim, 1); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); +} + + +TEST_F(TadTests, test_tad_order_4) { + Nd4jLong xShape[10] = {3, 10, 20, 30, 600 ,30, 1, 8192, 1, 99}; + Nd4jLong tShape[8] = {2, 20, 30, 30, 1, 8192, 1, 99}; + shape::TAD xTad; + int dim[2] = {1, 2}; + xTad.init(xShape, dim, 2); + xTad.createTadOnlyShapeInfo(); + + ASSERT_TRUE(shape::equalsStrict(tShape, xTad.tadOnlyShapeInfo)); +} + +TEST_F(TadTests, test_column_1) { + auto x = NDArrayFactory::create('c', {5, 2}); + auto tadPack = sd::ConstantTadHelper::getInstance().tadForDimensions(x.shapeInfo(), 0); + + ASSERT_EQ(1, shape::rank(tadPack.primaryShapeInfo())); + ASSERT_EQ(5, shape::length(tadPack.primaryShapeInfo())); + ASSERT_TRUE(shape::isVector(tadPack.primaryShapeInfo())); + + auto scalarViewPack = sd::ConstantTadHelper::getInstance().tadForDimensions(tadPack.primaryShapeInfo(), 0); + + ASSERT_TRUE(shape::equalsStrict(tadPack.primaryShapeInfo(), scalarViewPack.primaryShapeInfo())); +} + +/////////////////////////////////////////////////////////////////// +TEST_F(TadTests, calcOffsets_1) { + + Nd4jLong shapeInfoF[10] = {3, 2,3,4, 1,2,6, 8192, 1, 102}; + Nd4jLong shapeInfoC[10] = {3, 2,3,4, 12,4,1, 8192, 1, 99}; + Nd4jLong shapeInfoFC[10] = {3, 2,3,4, 1,2,6, 8192, 1, 99};; + + Nd4jLong expOffsetsF[24] = 
{0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23}; + Nd4jLong expOffsetsC[24] = {0,12,4,16,8,20,1,13,5,17,9,21,2,14,6,18,10,22,3,15,7,19,11,23}; + + Nd4jLong offsets[24]; + + shape::calcOffsets(shapeInfoF, offsets, 'f'); + + for (int e = 0; e < 24; e++) + ASSERT_TRUE(offsets[e] == expOffsetsF[e]); + + shape::calcOffsets(shapeInfoC, offsets, 'f'); + + for (int e = 0; e < 24; e++) + ASSERT_TRUE(offsets[e] == expOffsetsC[e]); + + shape::calcOffsets(shapeInfoFC, offsets, 'f'); + + for (int e = 0; e < 24; e++) + ASSERT_TRUE(offsets[e] == expOffsetsF[e]); +} + + +///////////////////////////////////////////////////////////////// +TEST_F(TadTests, outerArrayIndexes_1) { + + NDArray x('c', {2,3,4,5}, sd::DataType::FLOAT32); + int maxIdxs[120]; + + NDArray y1('c', {3,5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude1 = {0,2}; + const int n1[] = {20,25,30,35, 80,85,90,95}; + int minIdx = 5; + + int N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y1.shapeInfo(), dimsToExclude1.data()); + ASSERT_TRUE(N == x.lengthOf()/y1.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n1[i] == maxIdxs[i]); + + NDArray y2('c', {4,5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude2 = {0,1}; + const int n2[] = {12,32,52, 72,92,112}; + minIdx = 12; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y2.shapeInfo(), dimsToExclude2.data()); + ASSERT_TRUE(N == x.lengthOf()/y2.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n2[i] == maxIdxs[i]); + + NDArray y3('c', {2,5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude3 = {1,2}; + const int n3[] = {64,69,74,79,84,89,94,99,104,109,114,119}; + minIdx = 9; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y3.shapeInfo(), dimsToExclude3.data()); + ASSERT_TRUE(N == x.lengthOf()/y3.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n3[i] == maxIdxs[i]); + + NDArray y4('c', {2,3}, sd::DataType::FLOAT32); + const std::vector dimsToExclude4 = {2,3}; + 
const int n4[] = {20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39}; + minIdx = 1; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y4.shapeInfo(), dimsToExclude4.data()); + ASSERT_TRUE(N == x.lengthOf()/y4.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n4[i] == maxIdxs[i]); + + NDArray y5('c', {2,4}, sd::DataType::FLOAT32); + const std::vector dimsToExclude5 = {1,3}; + const int n5[] = {65,66,67,68,69, 85,86,87,88,89, 105,106,107,108,109}; + minIdx = 5; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y5.shapeInfo(), dimsToExclude5.data()); + ASSERT_TRUE(N == x.lengthOf()/y5.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n5[i] == maxIdxs[i]); + + NDArray y6('c', {2,3,4}, sd::DataType::FLOAT32); + const std::vector dimsToExclude6 = {3}; + const int n6[] = {65,66,67,68,69}; + minIdx = 13; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y6.shapeInfo(), dimsToExclude6.data()); + ASSERT_TRUE(N == x.lengthOf()/y6.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n6[i] == maxIdxs[i]); + + NDArray y7('c', {4}, sd::DataType::FLOAT32); + const std::vector dimsToExclude7 = {0,1,3}; + const int n7[] = {15,16,17,18,19, 35,36,37,38,39, 55,56,57,58,59, 75,76,77,78,79, 95,96,97,98,99, 115,116,117,118,119}; + minIdx = 3; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y7.shapeInfo(), dimsToExclude7.data()); + ASSERT_TRUE(N == x.lengthOf()/y7.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n7[i] == maxIdxs[i]); + + NDArray y8('c', {5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude8 = {0,1,2}; + const int n8[] = {0,5,10,15, 20,25,30,35, 40,45,50,55, 60,65,70,75, 80,85,90,95, 100,105,110,115}; + minIdx = 0; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y8.shapeInfo(), dimsToExclude8.data()); + ASSERT_TRUE(N == x.lengthOf()/y8.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n8[i] == maxIdxs[i]); + + NDArray y9('c', {2}, 
sd::DataType::FLOAT32); + const std::vector dimsToExclude9 = {1,2,3}; + const int n9[] = {60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119}; + minIdx = 1; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y9.shapeInfo(), dimsToExclude9.data()); + ASSERT_TRUE(N == x.lengthOf()/y9.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n9[i] == maxIdxs[i]); + + NDArray y10('c', {3,4,5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude10 = {0}; + const int n10[] = {11, 71}; + minIdx = 11; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y10.shapeInfo(), dimsToExclude10.data()); + ASSERT_TRUE(N == x.lengthOf()/y10.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n10[i] == maxIdxs[i]); + + NDArray y11('c', {2,4,5}, sd::DataType::FLOAT32); + const std::vector dimsToExclude11 = {1}; + const int n11[] = {66, 86, 106}; + minIdx = 26; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y11.shapeInfo(), dimsToExclude11.data()); + ASSERT_TRUE(N == x.lengthOf()/y11.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n11[i] == maxIdxs[i]); + + NDArray y12('c', {3,2}, sd::DataType::FLOAT32); + const std::vector dimsToExclude12 = {0,2}; + const int n12[] = {0,2,4,5,7,9,10,12,14,15,17,19,60,62,64,65,67,69,70,72,74,75,77,79}; + minIdx = 0; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y12.shapeInfo(), dimsToExclude12.data()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n12[i] == maxIdxs[i]); + + NDArray y13('c', {3,2}, sd::DataType::FLOAT32); + const std::vector dimsToExclude13 = {0,2}; + const int n13[] = {1,3,6,8,11,13,16,18,61,63,66,68,71,73,76,78}; + minIdx = 1; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y13.shapeInfo(), dimsToExclude13.data()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n13[i] == maxIdxs[i]); + + NDArray 
y14('c', {4,5}, sd::DataType::FLOAT32); + const int n14[] = {12,32,52, 72,92,112}; + minIdx = 12; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y14.shapeInfo(), nullptr); + ASSERT_TRUE(N == x.lengthOf()/y14.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n14[i] == maxIdxs[i]); + + NDArray y15('c', {3,4,5}, sd::DataType::FLOAT32); + const int n15[] = {11, 71}; + minIdx = 11; + + N = shape::outerArrayIndexes(maxIdxs, minIdx, x.shapeInfo(), y15.shapeInfo(), nullptr); + ASSERT_TRUE(N == x.lengthOf()/y15.lengthOf()); + for(int i = 0; i < N; ++i) + ASSERT_TRUE(n15[i] == maxIdxs[i]); +} + + + +#endif //LIBND4J_TADTESTS_H diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ThreadsTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ThreadsTests.cpp new file mode 100644 index 000000000..e9965b077 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/ThreadsTests.cpp @@ -0,0 +1,271 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include + +using namespace samediff; +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class ThreadsTests : public testing::Test { +public: + ThreadsTests() { + nd4j_printf("\n",""); + } +}; + +TEST_F(ThreadsTests, th_test_1) { + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1023)); + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1024)); + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 1026)); + + ASSERT_EQ(1, ThreadsHelper::numberOfThreads(6, 2043)); + ASSERT_EQ(2, ThreadsHelper::numberOfThreads(6, 2048)); +} + + +TEST_F(ThreadsTests, th_test_2) { + // in this case we'll get better split over second loop - exactly 32 elements per thread + ASSERT_EQ(2, ThreadsHelper::pickLoop2d(32, 48, 1024)); + ASSERT_EQ(2, ThreadsHelper::pickLoop2d(6, 4, 16384)); + + // in this case we'll get better split over first loop - 2 loops/2048 elements per thread + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(32, 64, 1024)); + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(6, 6, 16384)); + + // in this case none of loops are good enough, but second loop is too small for split + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(6, 64, 32)); + + // all loops are good enough, but we go with bigger one, since small + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(2, 64, 32)); + + // obviously split goes into second loop, to give 1024 elements per thread + ASSERT_EQ(2, ThreadsHelper::pickLoop2d(2, 1, 2048)); +} + +TEST_F(ThreadsTests, th_test_3) { + // typical conv cases + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(4, 32, 3, 128)); + ASSERT_EQ(2, ThreadsHelper::pickLoop3d(4, 1, 128, 64)); + ASSERT_EQ(3, ThreadsHelper::pickLoop3d(4, 1, 3, 128)); + + // checking for optimal threads for conv inference + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 1, 3, 128)); + 
ASSERT_EQ(4, ThreadsHelper::numberOfThreads3d(4, 1, 3, 128)); + ASSERT_EQ(8, ThreadsHelper::numberOfThreads3d(8, 1, 3, 128)); + + // checking for optimal threads for conv training + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 16, 3, 128)); + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 8, 3, 128)); + + + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 8, 3, 64)); + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 8, 3, 64)); +} + +TEST_F(ThreadsTests, th_test_5) { + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 32, 112, 112)); + + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 32, 112, 112)); + + for (auto e = 0; e < 6; e++) { + auto span = Span3::build(1, e, 6, 0, 32, 1, 0, 112, 1, 0, 112, 1); + + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } +} + +TEST_F(ThreadsTests, th_test_4) { + // typical conv cases + ASSERT_EQ(2, ThreadsHelper::numberOfThreads2d(2, 32, 3)); + ASSERT_EQ(4, ThreadsHelper::numberOfThreads2d(4, 32, 3)); + ASSERT_EQ(6, ThreadsHelper::numberOfThreads2d(6, 32, 1)); + ASSERT_EQ(8, ThreadsHelper::numberOfThreads2d(8, 16, 64)); + + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(4, 32, 1)); + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(8, 19, 17)); + + // primes edge cases + ASSERT_EQ(6, ThreadsHelper::numberOfThreads2d(6, 19, 17)); + ASSERT_EQ(8, ThreadsHelper::numberOfThreads2d(8, 19, 17)); + + ASSERT_EQ(1, ThreadsHelper::pickLoop2d(8, 19, 17)); + + for (auto e = 0; e < 6; e++) { + auto span = Span2::build(1, e, 6, 0, 19, 1, 0, 17, 1); + + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } + + nd4j_printf("-----------------------\n",""); + for (auto e = 0; e < 6; e++) { + auto span = Span2::build(1, e, 6, 0, 32, 1, 0, 3, 1); + + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } +} + + +TEST_F(ThreadsTests, test_span_converage_1) { + for (int b = 1; b <= 128; b++) { + for (int c = 1; c <= 64; c++) { + for (int t = 1; t <= 64; t++) { + + auto threads = 
ThreadsHelper::numberOfThreads2d(t, b, c); + auto loop = ThreadsHelper::pickLoop2d(threads, b, c); + + if (t > 1 && threads == 1 && (b > 1 && c > 1)) { + nd4j_printf("Got 1 thread for [%i, %i] loop; initial max threads: %i\n", b, c, t) + } + + auto sum = 0; + for (auto a = 0; a < threads; a++) { + auto span = Span2::build(loop, a,threads, 0, b, 1, 0, c, 1); + + if (loop == 1) + sum += span.stopX() - span.startX(); + else if (loop == 2) + sum += span.stopY() - span.startY(); + else + throw std::runtime_error("Bad loop!"); + } + + if (loop == 1) + ASSERT_EQ(b, sum); + else + ASSERT_EQ(c, sum); + } + } + } +} + +TEST_F(ThreadsTests, validation_test_2d_1) { + if (1 > 0) + return; + + std::vector threads({1, 2, 4, 6, 8, 12, 16, 20, 32, 48, 64}); + + for (int e = 1; e < 1024; e++) { + for (int i = 1; i <= 1024; i++ ) { + for (auto t:threads) { + std::atomic sum; + sum.store(0); + + auto func = PRAGMA_THREADS_FOR_2D { + for (auto x = start_x; x < stop_x; x += inc_x) { + for (auto y = start_y; y < stop_y; y += inc_y) { + sum++; + } + } + }; + + samediff::Threads::parallel_for(func, 0, e, 1, 0, i, 1, t, true); + + ASSERT_EQ(e * i, sum.load()); + } + } + + nd4j_printf("Finished iteration %i\n", e); + } +} + +TEST_F(ThreadsTests, reduction_test_1) { + + auto func = PRAGMA_REDUCE_LONG { + int64_t sum = 0; + + for (auto e = start; e < stop; e++) { + sum++; + }; + + return sum; + }; + + auto sum = samediff::Threads::parallel_long(func, LAMBDA_AL {return _old + _new;}, 0, 8192, 1, 4); + ASSERT_EQ(8192, sum); +} + +static void _code(int thread_id) { + auto x = NDArrayFactory::create('c', {65536 * 16}); + x.assign(1.1f); +} + +TEST_F(ThreadsTests, crash_test_1) { + if (!Environment::getInstance().isCPU()) + return; + + for (int e = 0; e < 3; e++) { + std::vector threads(std::thread::hardware_concurrency()); + + // creating some threads + for (int t = 0; t < threads.size(); t++) + threads[t] = std::thread(_code, t); + + // blocking until everything is finished + for (auto 
&t:threads) + t.join(); + } +} + +/* +TEST_F(ThreadsTests, basic_test_1) { + if (!Environment::getInstance().isCPU()) + return; + + auto instance = samediff::ThreadPool::getInstance(); + + auto array = NDArrayFactory::create('c', {512, 768}); + auto like = array.like(); + auto buffer = array.bufferAsT(); + auto lbuffer = like.bufferAsT(); + + auto func = PRAGMA_THREADS_FOR { + PRAGMA_OMP_SIMD + for (uint64_t e = start; e < stop; e += increment) { + buffer[e] += 1.0f; + } + }; + + auto timeStartThreads = std::chrono::system_clock::now(); + samediff::Threads::parallel_for(func, 0, array.lengthOf()); + auto timeEndThreads = std::chrono::system_clock::now(); + auto outerTimeThreads = std::chrono::duration_cast (timeEndThreads - timeStartThreads).count(); + + auto timeStartOmp = std::chrono::system_clock::now(); + PRAGMA_OMP_PARALLEL_FOR_SIMD + for (uint64_t e = 0; e < array.lengthOf(); e ++) { + lbuffer[e] += 1.0f; + } + auto timeEndOmp = std::chrono::system_clock::now(); + auto outerTimeOmp = std::chrono::duration_cast (timeEndOmp - timeStartOmp).count(); + + ASSERT_NEAR((float) array.lengthOf(), array.sumNumber().e(0), 1e-5f); + + nd4j_printf("Threads time: %lld us; OMP time: %lld us; %p\n", outerTimeThreads, outerTimeOmp, instance) +} + */ \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/TypeCastTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/TypeCastTests.cpp new file mode 100644 index 000000000..351d64482 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/TypeCastTests.cpp @@ -0,0 +1,74 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by raver119 on 02/07/18. +// + +#include "testlayers.h" +#include +#include + +using namespace sd; +using namespace sd::ops; +using namespace sd::graph; + +class TypeCastTests : public testing::Test { +public: + +}; + +TEST_F(TypeCastTests, Test_Cast_1) { +#ifndef __CUDABLAS__ + const int limit = 100; + auto src = new double[limit]; + auto z = new float[limit]; + auto exp = new float[limit]; + + for (int e = 0; e < limit; e++) { + src[e] = static_cast(e); + exp[e] = static_cast(e); + } + + TypeCast::convertGeneric(nullptr, reinterpret_cast(src), limit, reinterpret_cast(z)); + + for (int e = 0; e < limit; e++) { + ASSERT_NEAR(exp[e], z[e], 1e-5f); + } + + delete[] src; + delete[] z; + delete[] exp; +#endif +} + +TEST_F(TypeCastTests, Test_ConvertDtype_1) { + + #ifndef __CUDABLAS__ + + float src[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; + float16 dst[5]; + float16 exp[] = {(float16) 1.0f, (float16) 2.0f, (float16) 3.0f, (float16) 4.0f, (float16) 5.0f}; + + convertTypes(nullptr, ND4J_FLOAT32, src, 5, ND4J_FLOAT16, dst); + + for (int e = 0; e < 5; e++) + ASSERT_NEAR(exp[e], dst[e], (float16) 0.01f); + + #endif +} diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableProxyTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableProxyTests.cpp new file mode 100644 index 000000000..9cc433a04 --- /dev/null +++ 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableProxyTests.cpp @@ -0,0 +1,175 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include + +using namespace sd; +using namespace sd::graph; + +class VariableProxyTests : public testing::Test { +public: + +}; + + +TEST_F(VariableProxyTests, Test_Simple_1) { + auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); + VariableSpace ref; + + ref.putVariable(119, x); + + ASSERT_TRUE(ref.hasVariable(119)); + + VariableProxy proxy(&ref); + + ASSERT_TRUE(proxy.hasVariable(119)); +} + + +TEST_F(VariableProxyTests, Test_Simple_2) { + auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); + VariableSpace ref; + + ASSERT_FALSE(ref.hasVariable(119)); + + VariableProxy proxy(&ref); + + ASSERT_FALSE(proxy.hasVariable(119)); + + proxy.putVariable(119, x); + + ASSERT_FALSE(ref.hasVariable(119)); + + ASSERT_TRUE(proxy.hasVariable(119)); +} + + +TEST_F(VariableProxyTests, Test_Simple_3) { + auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); + VariableSpace ref; + + 
ref.putVariable(119, x); + + ASSERT_TRUE(ref.hasVariable(119)); + + VariableProxy proxy(&ref); + + ASSERT_TRUE(proxy.hasVariable(119)); + + proxy.putVariable(119, y); + + ASSERT_TRUE(ref.hasVariable(119)); + + ASSERT_TRUE(proxy.hasVariable(119)); + + auto z0 = ref.getVariable(119)->getNDArray(); + auto z1 = proxy.getVariable(119)->getNDArray(); + + ASSERT_FALSE(z0 == z1); + ASSERT_TRUE(y == z1); + ASSERT_TRUE(x == z0); +} + +TEST_F(VariableProxyTests, Test_Simple_4) { + auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); + auto z = NDArrayFactory::create_('c', {2, 2}, {4, 1, 3, 2}); + VariableSpace ref; + + ref.putVariable(119, x); + ref.putVariable(118, z); + + ASSERT_TRUE(ref.hasVariable(119)); + + VariableProxy proxy(&ref); + + ASSERT_TRUE(proxy.hasVariable(119)); + + proxy.putVariable(119, y); + + ASSERT_TRUE(ref.hasVariable(119)); + ASSERT_TRUE(ref.hasVariable(118)); + + ASSERT_TRUE(proxy.hasVariable(119)); + ASSERT_TRUE(proxy.hasVariable(118)); + + auto z0 = ref.getVariable(119)->getNDArray(); + auto z1 = proxy.getVariable(119)->getNDArray(); + + ASSERT_FALSE(z0 == z1); + ASSERT_TRUE(y == z1); + ASSERT_TRUE(x == z0); +} + + +TEST_F(VariableProxyTests, Test_Cast_1) { + auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); + auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); + VariableSpace ref; + + ref.putVariable(-119, x); + + ASSERT_TRUE(ref.hasVariable(-119)); + + VariableProxy proxy(&ref); + auto cast = (VariableSpace *) &proxy; + + ASSERT_TRUE(cast->hasVariable(-119)); + + cast->putVariable(-119, y); + + ASSERT_TRUE(ref.hasVariable(-119)); + + ASSERT_TRUE(cast->hasVariable(-119)); + + auto z0 = ref.getVariable(-119)->getNDArray(); + auto z1 = cast->getVariable(-119)->getNDArray(); + + ASSERT_FALSE(z0 == z1); + ASSERT_TRUE(y == z1); + ASSERT_TRUE(x == z0); +} + + +TEST_F(VariableProxyTests, Test_Clone_1) { + auto x = NDArrayFactory::create_('c', {2, 2}, {1, 2, 3, 4}); 
+ auto y = NDArrayFactory::create_('c', {2, 2}, {4, 2, 3, 1}); + VariableSpace ref; + + ref.putVariable(118, x); + + VariableProxy proxy(&ref); + + proxy.putVariable(119, y); + + ASSERT_TRUE(proxy.hasVariable(118)); + ASSERT_TRUE(proxy.hasVariable(119)); + + auto clone = proxy.clone(); + + ASSERT_TRUE(clone->hasVariable(118)); + ASSERT_TRUE(clone->hasVariable(119)); + + delete clone; +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableSpaceTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableSpaceTests.cpp new file mode 100644 index 000000000..da686c207 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableSpaceTests.cpp @@ -0,0 +1,222 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class VariableSpaceTest : public testing::Test { +public: + int *cShape = new int[8]{2, 2, 2, 2, 1, 0, 1, 99}; + int *fShape = new int[8]{2, 2, 2, 1, 2, 0, 1, 102}; + + + ~VariableSpaceTest() { + delete[] cShape; + delete[] fShape; + } +}; + + +TEST_F(VariableSpaceTest, SettersGettersTest1) { + auto space1 = new VariableSpace(); + auto arrayA = NDArrayFactory::create_('c', {5, 5}); + auto arrayB = NDArrayFactory::create_('c', {3, 3}); + + space1->putVariable(1, arrayA); + space1->putVariable(2, arrayB); + + auto arrayRA = space1->getVariable(1); + auto arrayRB = space1->getVariable(2); + + ASSERT_TRUE(arrayA == arrayRA->getNDArray()); + ASSERT_TRUE(arrayB == arrayRB->getNDArray()); + + // we should survive this call + delete space1; +} + + +TEST_F(VariableSpaceTest, SettersGettersTest2) { + auto space1 = new VariableSpace(); + auto arrayA = NDArrayFactory::create_('c', {5, 5}); + auto arrayB = NDArrayFactory::create_('c', {3, 3}); + + auto varA = new Variable(arrayA); + auto varB = new Variable(arrayB); + + varA->markExternal(true); + + space1->putVariable(-1, varA); + space1->putVariable(2, varB); + + Nd4jLong expExternal = (25 * 4) + (8 * 8); + Nd4jLong expInternal = (9 * 4) + (8 * 8); + + ASSERT_EQ(expExternal, space1->externalMemory()); + ASSERT_EQ(expInternal, space1->internalMemory()); + + delete space1; +} + +TEST_F(VariableSpaceTest, EqualityTest1) { + VariableSpace space; + + std::string name("myvar"); + + auto arrayA = NDArrayFactory::create_('c', {3, 3}); + auto variableA = new Variable(arrayA, name.c_str()); + + space.putVariable(1, variableA); + + std::pair pair(1,0); + + ASSERT_TRUE(space.hasVariable(1)); + ASSERT_TRUE(space.hasVariable(pair)); + 
ASSERT_TRUE(space.hasVariable(&name)); + + auto rV1 = space.getVariable(1); + auto rV2 = space.getVariable(pair); + auto rV3 = space.getVariable(&name); + + ASSERT_TRUE(rV1 == rV2); + ASSERT_TRUE(rV2 == rV3); +} + +TEST_F(VariableSpaceTest, EqualityTest2) { + VariableSpace space; + + auto arrayA = NDArrayFactory::create_('c', {3, 3}); + + space.putVariable(1, arrayA); + + std::pair pair(1,0); + + ASSERT_TRUE(space.hasVariable(1)); + ASSERT_TRUE(space.hasVariable(pair)); + + auto rV1 = space.getVariable(1); + auto rV2 = space.getVariable(pair); + + ASSERT_TRUE(rV1 == rV2); +} + +TEST_F(VariableSpaceTest, CloneTests_1) { + VariableSpace spaceA; + + auto arrayA = NDArrayFactory::create_('c', {3, 3}); + arrayA->assign(1.0); + + spaceA.putVariable(1, arrayA); + + auto spaceB = spaceA.clone(); + + std::pair pair(1,0); + + ASSERT_TRUE(spaceB->hasVariable(1)); + ASSERT_TRUE(spaceB->hasVariable(pair)); + + auto arrayB = spaceB->getVariable(1)->getNDArray(); + + ASSERT_TRUE(arrayA->equalsTo(arrayB)); + + arrayB->assign(2.0); + + ASSERT_FALSE(arrayA->equalsTo(arrayB)); + + delete spaceB; +} + +TEST_F(VariableSpaceTest, CloneTests_2) { + VariableSpace spaceA; + + auto arrayA = NDArrayFactory::create_('c', {3, 3}); + arrayA->assign(1.0); + + auto variableA = new Variable(arrayA, "alpha"); + + std::string str("alpha"); + std::pair pair(2, 3); + + spaceA.putVariable(pair, variableA); + + ASSERT_TRUE(spaceA.hasVariable(&str)); + ASSERT_TRUE(spaceA.hasVariable(pair)); + + auto spaceB = spaceA.clone(); + + ASSERT_FALSE(spaceB->hasVariable(1)); + ASSERT_FALSE(spaceB->hasVariable(2)); + ASSERT_TRUE(spaceB->hasVariable(pair)); + ASSERT_TRUE(spaceB->hasVariable(&str)); + + auto arrayB = spaceB->getVariable(pair)->getNDArray(); + + ASSERT_TRUE(arrayA->equalsTo(arrayB)); + + arrayB->assign(2.0); + + ASSERT_FALSE(arrayA->equalsTo(arrayB)); + + delete spaceB; + + ASSERT_TRUE(spaceA.hasVariable(&str)); + ASSERT_TRUE(spaceA.hasVariable(pair)); +} + + +TEST_F(VariableSpaceTest, 
Test_DType_Conversion_1) { + /* + VariableSpace spaceA; + + auto arrayA = NDArrayFactory::create_('c', {3, 3}); + arrayA->assign(1.0); + + auto variableA = new Variable(arrayA, "alpha"); + + std::string str("alpha"); + std::pair pair(2, 3); + + spaceA.putVariable(pair, variableA); + + + auto sd = spaceA.template asT(); + auto sf = sd->template asT(); + + ASSERT_TRUE(sf->hasVariable(pair)); + + auto xf = sf->getVariable(pair)->getNDArray(); + + ASSERT_TRUE(arrayA->isSameShape(xf)); + ASSERT_TRUE(arrayA->equalsTo(xf)); + + delete sd; + delete sf; + */ +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableTests.cpp new file mode 100644 index 000000000..bf7c3f162 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/VariableTests.cpp @@ -0,0 +1,227 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_VARIABLETESTS_H +#define LIBND4J_VARIABLETESTS_H + +#include "testlayers.h" +#include +#include +#include + +using namespace sd; +using namespace sd::graph; + +class VariableTests : public testing::Test { +public: + +}; + +TEST_F(VariableTests, TestClone_1) { + auto array1 = NDArrayFactory::create_('c', {5, 5}); + array1->assign(1.0); + + auto var1 = new Variable(array1, "alpha"); + var1->setId(119); + + + auto var2 = var1->clone(); + + ASSERT_FALSE(var1->getNDArray() == var2->getNDArray()); + auto array2 = var2->getNDArray(); + + ASSERT_TRUE(array1->equalsTo(array2)); + ASSERT_EQ(var1->id(), var2->id()); + ASSERT_EQ(*var1->getName(), *var2->getName()); + + delete var1; + + std::string str("alpha"); + ASSERT_EQ(*var2->getName(), str); + array2->assign(2.0); + + ASSERT_NEAR(2.0, array2->meanNumber().e(0), 1e-5); + + delete var2; +} + +TEST_F(VariableTests, Test_FlatVariableDataType_1) { + flatbuffers::FlatBufferBuilder builder(1024); + auto original = NDArrayFactory::create('c', {5, 10}); + original.linspace(1); + + auto vec = original.asByteVector(); + + auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); + auto fBuffer = builder.CreateVector(vec); + auto fVid = CreateIntPair(builder, 1, 12); + + auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_FLOAT); + + auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_FLOAT, 0, fArray); + + builder.Finish(flatVar); + + auto ptr = builder.GetBufferPointer(); + + auto restoredVar = GetFlatVariable(ptr); + + auto rv = new Variable(restoredVar); + + ASSERT_EQ(1, rv->id()); + ASSERT_EQ(12, rv->index()); + + auto restoredArray = rv->getNDArray(); + + ASSERT_TRUE(original.isSameShape(restoredArray)); + ASSERT_TRUE(original.equalsTo(restoredArray)); + + delete rv; +} + 
+TEST_F(VariableTests, Test_FlatVariableDataType_2) { + flatbuffers::FlatBufferBuilder builder(1024); + auto original = NDArrayFactory::create('c', {5, 10}); + original.linspace(1); + + auto vec = original.asByteVector(); + + auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); + auto fBuffer = builder.CreateVector(vec); + auto fVid = CreateIntPair(builder, 1, 12); + + auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_DOUBLE); + + auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_DOUBLE, 0, fArray); + + builder.Finish(flatVar); + + auto ptr = builder.GetBufferPointer(); + + auto restoredVar = GetFlatVariable(ptr); + + auto rv = new Variable(restoredVar); + + ASSERT_EQ(1, rv->id()); + ASSERT_EQ(12, rv->index()); + + auto restoredArray = rv->getNDArray(); + + ASSERT_TRUE(original.isSameShape(restoredArray)); + ASSERT_TRUE(original.equalsTo(restoredArray)); + + delete rv; +} + + +TEST_F(VariableTests, Test_FlatVariableDataType_3) { + flatbuffers::FlatBufferBuilder builder(1024); + auto original = NDArrayFactory::create('c', {5, 10}); + auto floating = NDArrayFactory::create('c', {5, 10}); + original.linspace(1); + floating.linspace(1); + + auto vec = original.asByteVector(); + + auto fShape = builder.CreateVector(original.getShapeInfoAsFlatVector()); + auto fBuffer = builder.CreateVector(vec); + auto fVid = CreateIntPair(builder, 1, 12); + + auto fArray = CreateFlatArray(builder, fShape, fBuffer, sd::graph::DType::DType_DOUBLE); + + auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_DOUBLE, 0, fArray); + + builder.Finish(flatVar); + + auto ptr = builder.GetBufferPointer(); + + auto restoredVar = GetFlatVariable(ptr); + + auto rv = new Variable(restoredVar); + + ASSERT_EQ(1, rv->id()); + ASSERT_EQ(12, rv->index()); + + auto restoredArray = rv->getNDArray(); + auto conv = restoredArray->asT(); + + ASSERT_TRUE(floating.isSameShape(restoredArray)); + 
ASSERT_TRUE(floating.equalsTo(conv)); + + delete rv; +} + +/* +TEST_F(VariableTests, Test_FlatVariableDataType_4) { + flatbuffers::FlatBufferBuilder builder(1024); + auto original = NDArrayFactory::create('c', {5, 10}); + std::vector exp({5, 10}); + + auto vec = original.asByteVector(); + + auto fShape = builder.CreateVector(original.getShapeAsFlatVector()); + auto fVid = CreateIntPair(builder, 37, 12); + + auto flatVar = CreateFlatVariable(builder, fVid, 0, sd::graph::DType::DType_FLOAT, fShape, 0, 0, VarType_PLACEHOLDER); + + builder.Finish(flatVar); + + auto ptr = builder.GetBufferPointer(); + + auto restoredVar = GetFlatVariable(ptr); + + auto rv = new Variable(restoredVar); + + ASSERT_EQ(37, rv->id()); + ASSERT_EQ(12, rv->index()); + + //auto restoredArray = rv->getNDArray(); + ASSERT_EQ(PLACEHOLDER, rv->variableType()); + ASSERT_EQ(exp, rv->shape()); + + //ASSERT_TRUE(original.isSameShape(restoredArray)); + //ASSERT_TRUE(original.equalsTo(restoredArray)); + + delete rv; +} +*/ +TEST_F(VariableTests, Test_Dtype_Conversion_1) { + auto x = NDArrayFactory::create_('c', {2, 3}, {1, 2, 3, 4, 5, 6}); + Variable v(x, "alpha", 12, 3); + + auto vd = v.template asT(); + auto vf = vd->template asT(); + + ASSERT_EQ(*v.getName(), *vf->getName()); + ASSERT_EQ(v.id(), vf->id()); + ASSERT_EQ(v.index(), vf->index()); + + auto xf = vf->getNDArray(); + + ASSERT_TRUE(x->isSameShape(xf)); + ASSERT_TRUE(x->equalsTo(xf)); + + delete vd; + delete vf; +} + +#endif //LIBND4J_VARIABLETESTS_H diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/WorkspaceTests.cpp b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/WorkspaceTests.cpp new file mode 100644 index 000000000..4703cab3a --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/WorkspaceTests.cpp @@ -0,0 +1,291 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available 
under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_WORKSPACETESTS_H +#define LIBND4J_WORKSPACETESTS_H + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::memory; + +class WorkspaceTests : public testing::Test { + +}; + + +TEST_F(WorkspaceTests, BasicInitialization1) { + Workspace workspace(1024); + + ASSERT_EQ(1024, workspace.getCurrentSize()); + ASSERT_EQ(0, workspace.getCurrentOffset()); +} + +TEST_F(WorkspaceTests, BasicInitialization2) { + Workspace workspace(65536); + + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + + array.p(0, 1.0f); + array.p(5, 1.0f); + + auto v = array.reduceNumber(reduce::Sum); + auto f = v.e(0); + + ASSERT_NEAR(2.0f, f, 1e-5); + + ASSERT_TRUE(workspace.getCurrentOffset() > 0); +} + + +TEST_F(WorkspaceTests, BasicInitialization3) { + Workspace workspace; + + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + + array.p(0, 1.0f); + array.p(5, 1.0f); + + auto v = array.reduceNumber(reduce::Sum); + auto f = v.e(0); + + ASSERT_NEAR(2.0f, array.reduceNumber(reduce::Sum).e(0), 
1e-5); + + ASSERT_TRUE(workspace.getCurrentOffset() == 0); +} + + +TEST_F(WorkspaceTests, ResetTest1) { + Workspace workspace(65536); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + array.p(0, 1.0f); + array.p(5, 1.0f); + + workspace.scopeOut(); + for (int e = 0; e < 5; e++) { + workspace.scopeIn(); + + auto array2 = NDArrayFactory::create('c', {5, 5}, &ctx); + array2.p(0, 1.0f); + array2.p(5, 1.0f); + + ASSERT_NEAR(2.0f, array2.reduceNumber(reduce::Sum).e(0), 1e-5); + + workspace.scopeOut(); + } + + ASSERT_EQ(65536, workspace.getCurrentSize()); + ASSERT_EQ(0, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getSpilledSize()); +} + + +TEST_F(WorkspaceTests, StretchTest1) { + if (!Environment::getInstance().isCPU()) + return; + + Workspace workspace(128); + void* ptr = workspace.allocateBytes(8); + workspace.scopeOut(); + ASSERT_EQ(0, workspace.getSpilledSize()); + ASSERT_EQ(0, workspace.getSpilledSecondarySize()); + ASSERT_EQ(0, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); + + + workspace.scopeIn(); + for (int e = 0; e < 10; e++) { + + workspace.allocateBytes(128); + + } + ASSERT_EQ(128 * 9, workspace.getSpilledSize()); + workspace.scopeOut(); + workspace.scopeIn(); + + ASSERT_EQ(0, workspace.getCurrentOffset()); + + // we should have absolutely different pointer here, due to reallocation + void* ptr2 = workspace.allocateBytes(8); + + //ASSERT_FALSE(ptr == ptr2); + + + ASSERT_EQ(1280, workspace.getCurrentSize()); + ASSERT_EQ(0, workspace.getSpilledSize()); +} + +TEST_F(WorkspaceTests, NewInWorkspaceTest1) { + if (!Environment::getInstance().isCPU()) + return; + + Workspace ws(65536); + + ASSERT_EQ(65536, ws.getCurrentSize()); + ASSERT_EQ(0, ws.getCurrentOffset()); + + ASSERT_FALSE(MemoryRegistrator::getInstance().hasWorkspaceAttached()); + + MemoryRegistrator::getInstance().attachWorkspace(&ws); + + 
ASSERT_TRUE(MemoryRegistrator::getInstance().hasWorkspaceAttached()); + + auto ast = NDArrayFactory::create_('c', {5, 5}); + + ASSERT_TRUE(ws.getCurrentOffset() > 0); + + delete ast; + + MemoryRegistrator::getInstance().forgetWorkspace(); + + ASSERT_FALSE(MemoryRegistrator::getInstance().hasWorkspaceAttached()); + ASSERT_TRUE(MemoryRegistrator::getInstance().getWorkspace() == nullptr); +} + + +TEST_F(WorkspaceTests, NewInWorkspaceTest2) { + Workspace ws(65536); + LaunchContext ctx; + ctx.setWorkspace(&ws); + + ASSERT_EQ(65536, ws.getCurrentSize()); + ASSERT_EQ(0, ws.getCurrentOffset()); + + MemoryRegistrator::getInstance().attachWorkspace(&ws); + + auto ast = NDArrayFactory::create_('c', {5, 5}, &ctx); + + ASSERT_TRUE(ws.getCurrentOffset() > 0); + + delete ast; + + MemoryRegistrator::getInstance().forgetWorkspace(); +} + +TEST_F(WorkspaceTests, CloneTest1) { + if (!Environment::getInstance().isCPU()) + return; + + Workspace ws(65536); + + ws.allocateBytes(65536 * 2); + + ASSERT_EQ(65536 * 2, ws.getSpilledSize()); + + auto clone = ws.clone(); + + ASSERT_EQ(65536 * 2, clone->getCurrentSize()); + ASSERT_EQ(0, clone->getCurrentOffset()); + ASSERT_EQ(0, clone->getSpilledSize()); + + delete clone; +} + +TEST_F(WorkspaceTests, Test_Arrays_1) { + Workspace ws(65536); + LaunchContext ctx; + ctx.setWorkspace(&ws); + + auto x = NDArrayFactory::create('c', {3, 3}, {1, 2, 3, 4, 5, 6, 7, 8, 9}, &ctx); + + // x.printIndexedBuffer("x0"); + + auto y = NDArrayFactory::create('c', {3, 3}, {-1, -2, -3, -4, -5, -6, -7, -8, -9}, &ctx); + + // x.printIndexedBuffer("x2"); + + auto z = NDArrayFactory::create('c', {3, 3}, {0, 0, 0, 0, 0, 0, 0, 0, 0}, &ctx); + + MmulHelper::mmul(&x, &y, &z); + + y.assign(&x); + + + // x.printIndexedBuffer("x3"); + // y.printIndexedBuffer("y"); + // z.printIndexedBuffer("z"); +} + +#ifdef GRAPH_FILES_OK +TEST_F(WorkspaceTests, Test_Graph_1) { + auto graph = GraphExecutioner::importFromFlatBuffers("./resources/ae_00.fb"); + auto workspace = 
graph->getVariableSpace()->workspace(); + + auto status = GraphExecutioner::execute(graph); + ASSERT_EQ(Status::OK(), status); + + delete graph; +} +#endif + +TEST_F(WorkspaceTests, Test_Externalized_1) { + if (!Environment::getInstance().isCPU()) + return; + + char buffer[10000]; + ExternalWorkspace pojo((Nd4jPointer) buffer, 10000, nullptr, 0); + + ASSERT_EQ(10000, pojo.sizeHost()); + ASSERT_EQ(0, pojo.sizeDevice()); + + Workspace ws(&pojo); + ASSERT_EQ(10000, ws.getCurrentSize()); + ASSERT_EQ(10000, ws.getAllocatedSize()); + LaunchContext ctx; + ctx.setWorkspace(&ws); + + auto x = NDArrayFactory::create('c', {10, 10}, &ctx); + + // only buffer size goes into account + ASSERT_EQ(400, ws.getUsedSize()); + ASSERT_EQ(400, ws.getCurrentOffset()); + + x.assign(2.0); + + float m = x.meanNumber().e(0); + ASSERT_NEAR(2.0f, m, 1e-5); +} + +// TODO: uncomment this test once long shapes are introduced +/* +TEST_F(WorkspaceTests, Test_Big_Allocation_1) { + Workspace ws(65536); + NDArray x('c', {256, 64, 384, 384}, &ws); +} +*/ + + +#endif //LIBND4J_WORKSPACETESTS_H \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/WorkspaceTests.cu b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/WorkspaceTests.cu new file mode 100644 index 000000000..4d4538dae --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/WorkspaceTests.cu @@ -0,0 +1,62 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include "testlayers.h" +#include +#include +#include +#include + +using namespace sd; +using namespace sd::memory; + +class CudaWorkspaceTests : public testing::Test { + +}; + +TEST_F(CudaWorkspaceTests, Basic_Tests_1) { + Workspace workspace(65536, 65536); + + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + auto array = NDArrayFactory::create('c', {5, 5}, &ctx); + + ASSERT_EQ(108, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); + + array.e(0); + + ASSERT_EQ(100, workspace.getCurrentSecondaryOffset()); +} + +TEST_F(CudaWorkspaceTests, Basic_Tests_2) { + Workspace workspace(65536, 65536); + + ASSERT_EQ(0, workspace.getCurrentOffset()); + LaunchContext ctx; + ctx.setWorkspace(&workspace); + auto array = NDArrayFactory::create('c', {5, 5}, {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, &ctx); + + ASSERT_EQ(108, workspace.getCurrentOffset()); + ASSERT_EQ(0, workspace.getCurrentSecondaryOffset()); +} \ No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/suppressions.txt b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/suppressions.txt new file mode 100644 index 000000000..6e5924015 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/suppressions.txt @@ -0,0 +1,2 @@ +#std::vector +leak:std::vector \ 
No newline at end of file diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/testinclude.h b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/testinclude.h new file mode 100644 index 000000000..f792cd835 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/testinclude.h @@ -0,0 +1,54 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// Created by agibsonccc on 1/15/17. 
+// + +#ifndef LIBND4J_TESTINCLUDE_H +#define LIBND4J_TESTINCLUDE_H +#include "testlayers.h" +#include +#include + +//https://stackoverflow.com/questions/228005/alternative-to-itoa-for-converting-integer-to-string-c +FORCEINLINE std::string int_array_to_string(Nd4jLong int_array[], Nd4jLong size_of_array) { + std::string returnstring = "["; + for (int temp = 0; temp < size_of_array; temp++) { + returnstring += std::to_string(int_array[temp]); + if(temp < size_of_array - 1) + returnstring += ","; + } + returnstring += "]"; + return returnstring; +} + +FORCEINLINE ::testing::AssertionResult arrsEquals(Nd4jLong n, Nd4jLong *assertion,Nd4jLong *other) { + for(int i = 0; i < n; i++) { + if(assertion[i] != other[i]) { + std::string message = std::string("Failure at index ") + std::to_string(i) + std::string(" assertion: ") + int_array_to_string(assertion,n) + std::string(" and test array ") + int_array_to_string(other,n) + std::string(" is not equal"); + return ::testing::AssertionFailure() << message; + } + + } + return ::testing::AssertionSuccess(); + +} + + +#endif //LIBND4J_TESTINCLUDE_H diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/testlayers.h b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/testlayers.h new file mode 100644 index 000000000..fb8032df8 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/layers_tests/testlayers.h @@ -0,0 +1,43 @@ +/* ****************************************************************************** + * + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * See the NOTICE file distributed with this work for additional + * information regarding copyright ownership. 
+ * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#ifndef LIBND4J_TESTLAYERS_H +#define LIBND4J_TESTLAYERS_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#endif //LIBND4J_TESTLAYERS_H diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/libnd4j_tests/CMakeLists.txt b/cavis-native/cavis-native-lib/src/test/tests_cpu/libnd4j_tests/CMakeLists.txt new file mode 100644 index 000000000..f9a81567d --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/libnd4j_tests/CMakeLists.txt @@ -0,0 +1,301 @@ +cmake_minimum_required(VERSION 3.9) +project(dev_tests) +message("Starting up tests build") +set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/../../cmake" ${CMAKE_MODULE_PATH}) + +# Download and unpack googletest at configure time +configure_file(../CMakeLists.txt.in googletest-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download ) +if(result) + message(FATAL_ERROR "CMake step for googletest failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . 
+ RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download ) +if(result) + message(FATAL_ERROR "Build step for googletest failed: ${result}") +endif() + +# OPTIONAL MKL-DNN +if ("${BUILD_MKLDNN}") + # Download and unpack mkl-dnn at configure time + configure_file(../../CMakeLists.txt.mkldnn.in mkldnn-download/CMakeLists.txt) + execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download ) + if(result) + message(FATAL_ERROR "CMake step for mkldnn failed: ${result}") + endif() + execute_process(COMMAND ${CMAKE_COMMAND} --build . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-download ) + if(result) + message(FATAL_ERROR "Build step for mkldnn failed: ${result}") + endif() + + add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src + ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-build + EXCLUDE_FROM_ALL) + set(mkldnn_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/mkldnn-src) + set(HAVE_MKLDNN 1) + add_definitions("-DHAVE_MKLDNN") + include_directories(${mkldnn_SOURCE_DIR}/include ${mkldnn_SOURCE_DIR}/external/mklml_lnx_2019.0.3.20190220/include ${mkldnn_SOURCE_DIR}) + set(MKLDNN dnnl) +endif() + +if (${HELPERS_armcompute}) + find_package(ARMCOMPUTE REQUIRED) + + if(ARMCOMPUTE_FOUND) + message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}") + set(HAVE_ARMCOMPUTE 1) + # Add preprocessor definition for ARM Compute NEON + add_definitions(-DARMCOMPUTENEON_ENABLED) + include_directories(${ARMCOMPUTE_INCLUDE}) + endif() + +endif() + +# Download and unpack flatbuffers at configure time +configure_file(../../CMakeLists.txt.in flatbuffers-download/CMakeLists.txt) +execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . 
+ RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-download ) +if(result) + message(FATAL_ERROR "CMake step for flatbuffers failed: ${result}") +endif() +execute_process(COMMAND ${CMAKE_COMMAND} --build . + RESULT_VARIABLE result + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-download ) +if(result) + message(FATAL_ERROR "Build step for flatbuffers failed: ${result}") +endif() + +# Add flatbuffers directly to our build. +add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src + ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-build + EXCLUDE_FROM_ALL) + +set(HAVE_FLATBUFFERS 1) +set(FLATBUFFERS_PATH ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers-src) +include_directories(${FLATBUFFERS_PATH}/include) + + +# Prevent overriding the parent project's compiler/linker +# settings on Windows +set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + +# Add googletest directly to our build. This defines +# the gtest and gtest_main targets. +add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/googletest-src + ${CMAKE_CURRENT_BINARY_DIR}/googletest-build + EXCLUDE_FROM_ALL) + +set(gtest_SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/googletest-src) +add_definitions(-D__STANDALONE_BUILD__=true) + +include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) +include_directories(../../include) +if(LINUX) + link_directories(/usr/local/lib) + link_directories(/usr/lib) + link_directories(/lib) +endif() + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) + +if(APPLE) + message("Using apple") + link_directories(/usr/local/lib) + link_directories(/usr/lib) + link_directories(/lib) +endif() +if(WIN32) + get_property(dirs DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} PROPERTY INCLUDE_DIRECTORIES) + foreach(dir ${dirs}) + message(STATUS "dir='${dir}'") + endforeach() +endif() + +# -fsanitize=address +# -fsanitize=leak +if (APPLE) + set(CMAKE_CXX_FLAGS " -O0 -g -fPIC -std=c++11 -D__APPLE_OS__=true -DSD_APPLE_BUILD=true") +elseif(WIN32) + if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL 
"GNU") + set(CMAKE_CXX_FLAGS " -g -fPIC -std=c++11 -Wa,-mbig-obj") + endif() +else() + set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -ffast-math -DFFAST_MATH=true -DLINUX_BUILD=true") + + if ("${_RELEASE}" OR CMAKE_BUILD_TYPE STREQUAL "Release") + message("Release build for tests") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3 -fPIC -std=c++11 -D_RELEASE=true") + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*") + set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native") + else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native") + endif() + else() + set(CMAKE_CXX_FLAGS " -g -O0 -fPIC -std=c++11 ") + if (NOT SD_CUDA) + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fsanitize=address") + endif() + endif() + + if (${F16C}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mf16c -DSD_F16C=true") + endif() +endif() + +if ("${_RELEASE}" OR CMAKE_BUILD_TYPE STREQUAL "Release") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DRELEASE_BUILD=true") +else() + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DDEBUG_BUILD=true") +endif() + +if ("${SD_EXPERIMENTAL}" STREQUAL "yes") + message("Experimental mode ENABLED") + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__ND4J_EXPERIMENTAL__=true") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__ND4J_EXPERIMENTAL__=true") +endif() + +# tests are always compiled with all ops included +SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSD_ALL_OPS=true -DDEFAULT_ENGINE=samediff::ENGINE_CPU") + +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang") + # using Clang + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -Wno-logical-op-parentheses -Wno-inconsistent-missing-override -Wno-implicit-conversion-floating-point-to-bool -Wno-delete-non-virtual-dtor") +elseif("${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang") + message("AppleClang used") + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -Wno-logical-op-parentheses -Wno-inconsistent-missing-override -Wno-implicit-conversion-floating-point-to-bool -Wno-delete-non-virtual-dtor") +elseif 
("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Intel") + # using Intel C++ + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${ARCH_TUNE} -fp-model fast") +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC") + # using Visual Studio C++ + +elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + # using GCC + SET( CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fmax-errors=2") +endif() + + +IF(${CMAKE_SYSTEM_NAME} MATCHES "Linux") + include_directories("/usr/include") + include_directories("/usr/local/include") +ENDIF(${CMAKE_SYSTEM_NAME} MATCHES "Linux") + +if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9) + message(FATAL_ERROR "You need at least GCC 4.9") +endif() + +message("Looking for OpenMP") +find_package(OpenMP) +if (OPENMP_FOUND) + set (CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") +else() + message("OPENMP NOT FOUND") +endif() + +if ("${OPENBLAS}" OR CMAKE_BUILD_TYPE STREQUAL "Release" OR "${BUILD_MKLDNN}") + message("Looking for BLAS") + find_package(BLAS REQUIRED) + if (BLAS_FOUND) + message("Found external BLAS library: ${BLAS_LIBRARIES}") + add_definitions(-D__EXTERNAL_BLAS__=true) + endif() +endif() + +file(GLOB_RECURSE PERF_SOURCES false ../../include/performance/*.cpp ../../include/performance/*.h) +file(GLOB_RECURSE EXCEPTIONS_SOURCES false ../../include/exceptions/*.cpp ../../include/exceptions/*.h) +file(GLOB_RECURSE EXEC_SOURCES false ../../include/execution/*.cpp ../../include/execution/*.h) +file(GLOB_RECURSE TYPES_SOURCES false ../../include/types/*.cpp ../../include/types/*.h) +file(GLOB_RECURSE ARRAY_SOURCES false ../../include/array/*.cpp ../../include/array/*.h) +file(GLOB_RECURSE MEMORY_SOURCES false ../../include/memory/*.cpp ../../include/memory/*.h) +file(GLOB_RECURSE GRAPH_SOURCES false ../../include/graph/*.cpp ../../include/graph/*.h) +file(GLOB_RECURSE CUSTOMOPS_SOURCES false ../../include/ops/declarable/generic/*.cpp) +file(GLOB_RECURSE 
CUSTOMOPS_GENERIC_SOURCES false ../../include/ops/declarable/helpers/cpu/*.cpp ../../include/ops/declarable/helpers/impl/*.cpp) +file(GLOB_RECURSE OPS_SOURCES false ../../include/ops/impl/*.cpp ../../include/ops/declarable/impl/*.cpp ../../include/ops/*.h) +file(GLOB_RECURSE INDEXING_SOURCES false ../../include/indexing/*.cpp ../../include/indexing/*.h) +file(GLOB_RECURSE HELPERS_SOURCES false ../../include/helpers/*.cpp) +file(GLOB_RECURSE LEGACY_SOURCES false ../../include/legacy/impl/*.cpp ../../include/legacy/cpu/*.cpp ../../include/legacy/*.h) +file(GLOB_RECURSE LOOPS_SOURCES false ../../include/loops/*.cpp ../../include/loops/*.h) + +# optionally build mkldnn +if ("${BUILD_MKLDNN}") + file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp) +endif() + +if(HAVE_ARMCOMPUTE) + file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/armcomputeUtils.h) +endif() + +message("CPU backend") +add_definitions(-D__CPUBLAS__=true) + +if (CMAKE_BUILD_TYPE STREQUAL "Debug" AND NOT(MINGW) AND NOT(APPLE)) + SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -rdynamic -Wl,-export-dynamic") + SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -export-dynamic") +endif() + + file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in + ../include/loops/cpu/compilation_units/*.cpp.in ../include/helpers/cpu/loops/*.cpp.in) + + foreach(FL_ITEM ${COMPILATION_UNITS}) + genCompilation(FL_ITEM) + endforeach() + +# this function strips path from file name, basically making up short file name, i.e. 
file.cpp +function(SHORTNAME LONG_NAME OUTPUT) + SET(_TMP_STR "") + string (REGEX REPLACE ".*/" "" _TMP_STR "${LONG_NAME}") + set (${OUTPUT} "${_TMP_STR}" PARENT_SCOPE) +endfunction() + +# now we ned to join two lists +# first of all we'll build truncated list of files in platform sources +# and list of priority implementations from platform helpers +#set(CUSTOMOPS_HELPERS_SOURCES "") +#set(SHORT_NAMES "") +#foreach(LONG_NAME ${CUSTOMOPS_PLATFORM_SOURCES}) +# SHORTNAME("${LONG_NAME}" "SHORT_NAME") +# set(CUSTOMOPS_HELPERS_SOURCES ${CUSTOMOPS_HELPERS_SOURCES} ${LONG_NAME}) +# set(SHORT_NAMES ${SHORT_NAMES} ${SHORT_NAME}) +#endforeach() + +# now we're going to filter generic helpers, to exclude platform implementations +#foreach(LONG_NAME ${CUSTOMOPS_GENERIC_SOURCES}) +# SHORTNAME("${LONG_NAME}" "SHORT_NAME") + + # and now we add this op ONLY if it wasn't announced in platform helpers +# string(FIND "${SHORT_NAMES}" "${SHORT_NAME}" "LOC") +# if (${LOC} EQUAL -1) +# set(CUSTOMOPS_HELPERS_SOURCES ${CUSTOMOPS_HELPERS_SOURCES} ${LONG_NAME}) +# endif() +#endforeach() + + +file(GLOB_RECURSE TEST_SOURCES false ../layers_tests/*.cpp ../layers_tests/*.h) + + +# Filter out any source files from */CMakeFiles/* paths. these tend to cause problems such a multiple main definitions. 
+set (EXCLUDE_DIR "/CMakeFiles/") +foreach (TMP_PATH ${TEST_SOURCES}) + string (FIND ${TMP_PATH} ${EXCLUDE_DIR} EXCLUDE_DIR_FOUND) + if (NOT ${EXCLUDE_DIR_FOUND} EQUAL -1) + list (REMOVE_ITEM TEST_SOURCES ${TMP_PATH}) + endif () +endforeach(TMP_PATH) + + +add_executable(runtests ${LOOPS_SOURCES} ${LEGACY_SOURCES} ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} + ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} + ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} + ${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES}) + +target_link_libraries(runtests gtest ${MKLDNN} ${ARMCOMPUTE_LIBRARIES} gtest_main ${BLAS_LIBRARIES}) + diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/arr_3,4_float32.npy b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/arr_3,4_float32.npy new file mode 100644 index 0000000000000000000000000000000000000000..ead15844c59e2c77138c84d14a805aa07d4172e4 GIT binary patch literal 176 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+l>qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%oItnJ5ItsN4WCJb+Flev`QVu`_#0@~a0EiC&@dF@caAaT*0Ad9oHgE(0 D6k#F} literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/assert_type_rank2_int64.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/assert_type_rank2_int64.fb new file mode 100644 index 0000000000000000000000000000000000000000..8d643d6c424b65c682c15ea23131038980f93cc3 GIT binary patch literal 1664 zcmb`HziU%b6vt1}*rYK?lqf+uWGEQ0*bW&S3a0Ad;NXx^h|lDq!Sp4*JfQ&x2NC}Q zaddQWaOmLRAPC~*V*daqad7D5q{h$py?0{?6*2X|_nmX@efRvnw`vKM?KDEp3RbeR zRe;$Qa2~Wk71V**F?I{Fv}k!W(L!dDQmpdHf^9yR;W_a22qmaL)eHP?yE6zwzmxo;(+M5s zth3>`;#+V^D!+tZ2a5Z_kIAG7Z-7;xTGc=WDA%;SNJ(z7F!MQz-V26(uj5}LL)%Sl z%9iRTyUXApn=Btl^6PE)M%rBYnZ$S)dqwQr+zfBs+wy}5t985fPxGQW_Q5ltxp@rk zgCXbw)w2aOFB@>>t(?o^A3wjXES-IY7|(}dQBS;XcRDZ9H_;u<>Tca}#Us+)XrR>6 
z1)X~X6pw0ngLsnwOeEL)PP)NUsM%Q?6L~eBYGtQa(B)E$eQRztlgj5J(Cp}i);q2DSNT%8jP9%a z*1+FLvC59f9;1K<05&hwn}ea`uv@A;kcJkM=5E-g1ta#(HZOoM4OO=gu@PW~8(f*H^Z+CeKY zW+`Rr?9mcaOHS$0s*Im(*8h1NI~_5mj~k4B!I%@DHzxf#@+*um_WyZ`wu%p44 zll8_-Eiq=Sn)^ONIqI$k@iywNH)b~&e}?m*6>>o_UdK24Xsg&90bGq>tZ!;*wnI#Dv zsBN*4biaWa2tOQhTa7&eg*z{1D}Qn0iS0oMGwV0(|myrF=mx zG%$zN=SumgY6wMU|4{5eVi(|e$@NS|^I|DhzCV4-ixWD23)E>0z+KD?xCyQS>Fpwr zzDBTWQ!{7;TdqBJ<%u7EUm-H9u{w^EaWtJwA93e3uAQKm&Uy#>(?bKP?s#G_(GN%O zk?xJ+a48H9$0e&S6T+=sC?1{tbNKqZ?QeC1OYkGiH8ua}?(zx;g&5XadsB9OtE*+dw*zp4x!ZiPmP9H0Nt-&?9#Ea5CN7=f#02dG$#HUQ8!` zi29bT-rccTbvBe=j~ZTDOkKy}too@mAwlMfqT?l}|K2W$OpUK1;o_;xpx4T1;KX z;jH@B2zV4LNv}QlHA>kq^*Z5mh81V4Z-=YbJ@Bc|i)mj_ue;cr>fBx|%ed*e9tLu~UBwUvDk=I!ap>GoJFL+c%~|;0x{>``L?XuXqS6 z-@Ss>8hjk9_bjKdpT+(Jt623qR zBPi#2q=WLs^fdo@gi3x|Y<1iql2-ea$*`FHsv zPanE|8r%fZ#1xPwN{KTs&8<7Wtp3HT^0T11&_6IRh{@(#I4-7(T?5~=byQEOk~&_^ z2fu2M;;Ku*_hJ$HALU}0>sQ+6Y=>9zxD9A7H37{{`KMS?85d73hQJpb|4sk&#$CQS z3ueGga1C4rlVAdjgE64E(1%r@Dv1m0s8?+K2Zs{z?qr;jUaF*yt1GLquk-DK^Z)KA zF#bMgXyYFOUvT{2`-`{Fi2lpdf&4TIX2A@&0j9uZZ~-_iNCzqMmFl3Cy2W(hy}#(* zPtk0uv|g5g1L>jGw+rc^{a*NbIOcPPrH5-F@CEg7l|Ahee6#!Co*v|JK^&xscd_A1;Im6P!G<;pOy;8oQ7RFc`oL_qWWA$)YdBp;~ zpLEa6bL3~iG`In-0mTQcCvMGa7?5U9 zg4ciohIG6WD?K)YMxZ$1l8yz%3w;aKv3#j=`P$-QPFm=?>T}CirKP+Z`E?z?t{){c zs1qdJU$IO5(EVrUGMO{jPM|s14%Gf8&;V2>H;G2qf@#V=7t~j&y{aR;N&{V7PwlhM$P&?gyxI8aOAIafu0|y3Ei4l%x zL*lK2>%~N3aC`CqUy*tL&2zMTa1Se=x0;>h=x&ATNOzOqJa`L?g45s_(6uzc#aZ30 zBCj#~*~pp?fAGcnpUMw{#=w45re(^jU3wRgr@OD|t1~D4RP4Gd=d5_8HKC16L?=k< z$0XPXbl)D@8^tbW9E0zxC*bUYQ`*)1la8A~A#HGaUQ)YUOlu+*>P$@rE z?R$mX{~0W0PbOcj#L88B?ekU5BN{*5PvbiGU+*1`uIr0I-gs@D|Nj90l(!W5-{amX zz`F--jc;YV-uIFv>>*&tYwtL`dWR(5O5SOI$LCc4HGcAe{z;_q(EM_-L}Ogc7rATx zyO!7Ut>Tk-PvIA>#nPw#FZ27H{k43GNMSpv`e zcU|7M_d}L9zYovTX1(Jp+t=M4=28K^pTj3zl+U-(tYmND^x*pa5q#1^CEpwN7vTFU zZ5H=hP;5NBxxjxfEKi4Bmj&avXaQ@|zpLg`4PHQC?2al=p| z7T?_dbYx4!D{C*>AF+D6m+jRLsf>QdW9eA8^s&-DPt9-dbM!;+Z8!P-d02Z#Y4;TJIlh8|aTG)2#E6k;s-Ukt30KG9Kw4 
zNJshxQt?P4mC%2)o;{o#PHeUuuD*A-`F_{+a(%D(`Xly2#(2K(A@BO~EBdIoJmL3M zu}OXLMxBG4r1KAwS3>bn$>~DJ<+iV#IxpL2tNhrqmo{rX)L0ao6Z2z>F;@$4z741L zUo9nYx_7WES*zyFe|WUtD(Ai2cMI@rrj3gKrNsSV@6RavDQ?~_z;zrhy{jmLD=n^j z1-O0#m)4|G;*onke2%pz=%>|i>p4Se$sa@SS>PF4$J<-?`Q@~RJx@L9ZKrKpeU;W2 z<>f0sRjGYWf7)*-x%|WS-->;CZDEoQTTd)!f1*`c%xZJ3-Ae}hbcZ%+ZJo!H-DfW6 zJGtZ0e}x@Sjg8f#mfKsaH`7-krmBl$^;2^q`dWFAZ=4_0J|%G}Ih*xQ#GcOj z%gtYv(}JMhtzMuH+KW$s7ND`x?xl`9xbe{cqrKkO+C4jO2mbH#<)oFK$p>hpJ%e1#v?m}V-R z9&**_M;RbQj;|1ADiaf$AR&?-qXNOK6ryR45LpB!9ZB)Le!}RQTP-v~OS9=jgovd6 z7@-&O!a#Kw&0~GF1hYa2BSBAyAZQ7V)L!rDXOq}=*1`7Ew4!?7wqJZW?9`B`nB{4+ zmSx)u3-hA03T)e`NpPPJA%tKpvlUshZJ8{PTaX={^-Mxpb|Om~sn<;pD;G%OGdyFw zInT-D&E2i{|9Z=L*X0p%yIg0^ZR|Vq4tA7?_c5Ky{)D!fLRbh&LOU6nMNknsXh04j zn9xb`76K0+zU6cpiF2$(5O}?{gkS=%Im^gIL&toS+hjhQ3c1BUv*z2@7iQULe{VaT zW@_`+EAOSOmvw4nC6jP+LnI_W=dz<@yPY%(ag21c&3XdI704?Oi>^2@kC;M z#pLg$>C{L5>PDyL@1*vTzvWZZ$No(uu5TfKyQfnh`K$iU=_Iv}eBQk6V_RJ|lFv(q zolahln!8RXv3Y#X>j;+#7YJtw?SzwrCPE{jp1}8lg*fLaKJzNV#=v(fMh*SjsW_G& z=?W7iMb?61dtosr4W7$~u!wN~y)Y)38AbgvJ&LS6hB?p0elqH)Qq2uSJikAc*{5_* zz2uJ5X`nLau!qUNBz7P1Rh(O3&$GU4TVu`36Wt!2yhb_#>*ck2P=2KDdx&kdmgg38 zqnh@}K>Ua3i1_7FM z;WcMp9wj~}e~*mCCuwuj!^+9>%-IJfUWZ1iaBMR18mg(x{_vh~T}h4@h)-u+T0A-< zN$+&TPPr9b6&pRfA>D@a1igv$b`OKseerTM^@hQhEsTAHZ8cX*S$2AaVx!sh< z=>*JBQkmn3V=ImNIDRMJ|MB;2fop7lg2&k>jBoJFk?!f&QXkL7etPI0Q+X~qZsap3 z*WHmc-v6>kftL6Ka_lpFE?G2)?iKRSWn&E)A^Qkq8`~s2{lj#yxjsM~pUc>V>qE@A z(tZ^33hOZLi5sAAPesMnc+9HYjC0P{W1h7~I$=n}7gj`|?!AjRH}(cBEw#YW;7ADF z;eeqW2kt67fSSrO{K?bTLBHOF5g(O-?XF!)JYs;s{c21K4u`Ei$8h^G6DBKjpgKhd z7fX~FJunxI2XA4(J}s_{a=>pojjKM=3u#t&zqIF0B&KwwBYx5;HT=2?!~We4r{Ax` zU~?-5wZ02+wVfEAb{bz%IiTT<0eE-xEG(~9Ve-HdoHL@A4yE+sFI#t`Zo>ii=i(}; zNVCAFm3o|YI2NNuXT#_<2W(Q9u%y%%e)RPv2o2F;z|Yg-(`Qt{CqM3ndEo=lw@C}T zGIQYg4mBJ?BQCp|2VLK{uqjr{YRf_$N!>ylKaz=5UDf?ZwnZX{fkT4Oa)tamJ}SeCqAB zu;*GH6gKDJ@{l5U?L#fj4z7Tqg>|?nWxsU0zZ3Rcsg;h@4Z&AuBH+n0I(TuLfWO7o zN-r*Vz$s|Mpu#q(xh)Yc?Hrcu-}Ym`>I(RFj|P7ATL;AY8KlF1v7_z$T6q6#IW!Jl 
zhMPx>7`;Cht|=_&v-uqSv_2J1@2kPJzdw&fqkADKtQ#`7Xz|H`3^=;Mh$YE4!EiJY zOIHj-Y*!VQwJPEE1|QTdYl7xK4B@Lcl5zg(jre%z!uYQ5s^H*>e)LWK0O!9NfpIfU z@Vp@sUwP*amfz{dEv@NN(hdV0&R4D4*&=Ywl^dx1>o9aBT!yaD7-)OT1h0o_aKjhv zaC4mnK1U@k3$^0{Ysae00m=*ehj8(dM2u3jOX?U0s{fcQEsApB{N^0#sXupOOwwtr zY*~k8rZ)I`fd&Kj_Mv^#RrsK_U9#G$(C3L7IAI#VrRFqg`1ncm*=WJQt*y}Bp+d!h zdN{tk9K*h8f-lz$!@r)7z{fK>Ff1q-qw#I2>36-b^fik2b3IrTOLK)9@v*ob#DFB+ z_#gACgYOA=p-kY?*OQ?yLxZs=j2L+MI!3D6u&_yw&jtugHOD}gr4}5Y*27oXEzoZ0 zmO8qIq&NENr2c1nA>^(aRfkloR>O24_IdO@7qEbZR|e7BM%Kq)X=Huw99>B| z_`Se69ti9!f6Z}^@N@ka7a~BRrFOP)LcZoIyW+~*Q_0s$Xz&B_b!OCf#uoF(e7JW_ Y*dBhnvHkqj$e*l~ZJGS88h_LLA0s$Vxc~qF literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/channels_last_b1_k2_s1_d1_SAME_crelu.fb new file mode 100644 index 0000000000000000000000000000000000000000..a9066228228644bf6bad903fcac6945298fb442a GIT binary patch literal 4320 zcmb_fVQf=X6uyJO227l@XoHJ3IAep=RYWt8-dmv2Mh96C-84ffYuPHLE$g=6TwsYv zG!ilwViu#zOo%}wNC@JD4IAahM6zEn`Jr)|XfPs~pa?Ro-}hc`*OqM{>?Gg2@1A?^ zIp>~p?tS-Jg-L~CHePZ?z9Xy~q8hwufcI8d0(vmndy5e90wJ~mRlxBXLhJxcnky;{kMe}q@X9KR zg-w*e6Nhj@WIp^`B;3Ly=9&|xQ-2&72L1$21K$BBfIc7&bOUs^4Xpzx2Dn}~a6drG z0l0=-pgs_4L}T#SzYVB~gq!C1OP8*|SScUi)%}r3U{(%wQH~(#*67%br=g2)HD)*C zIhOtmp=5`CG8=5;012TsVoke_3yFjqEu&>;AK-c|0K%Q=8~UTsbTIwF^?`6K7zsBE z(F&a&@a6;5N8L7HG@bNoEMl6YzF2H!U45jnwoV*8mq?&wg!3}vC}#0}BhUjJ0CoZ$ zKpW5kGyoxhx%8st`dFu2*U5d&58d%||DCkYJWd6Wvl<^Jk;m9jEE$xtWYxYK-nm<;nB{@z| z?=}=zYqIcYvjp$m81MsyKo11#01FVucpZ=nSg)A%o1I4;8_JMqhrH2!v_3bf;$9j>?yRY>m$aUV#*}eInQ$} zE3p-#XzD?P6h^YYQ}Fm>i|FVY5erxVXEUFXT;Ak`N(Q&NxDq)p^G zoctc~$|9Pv$KXeX?{mQ8orX5kzg<+bdldN<;L)Ga?ls#H3;fv__x=jV=yh(z9NwK; z4SnV~*GVvbY)8nvyti?+?6f%L-LH1Zp}H=$DblWrTbk5^5Bws_mU)%qSd$D~iMlG@cgx?lTh&TS ziQJI8OTDtotDY;UlUpk7a-g9=J#u=c^1O3FO?%6x7ChV~-;8@6KUOAOQ~Cy7uMIm@ z<(Jh;os7HQ-P)?|`Jr3xKh&ly<$G1p-h*=Cwzw*JVZVC0u!}j=vDEQn+)29Cbw1Bz 
z8)rLTHDAvYHd8?N)9Z>hg5j5(l;}5do>@~(;IU2_`}CSJ??Xk%lg=4sQaQ`0%R|sX zpRWE67vT?G&(F-cv47xj3FcusYH=zgbB*fYI`~I#&df9FoqLk|k?UCv$d1~u9I{ra zyq`|>yuEo)+K#$izeZ0>``TTyy?K*tFFx9PW_Fvre911q{(ZOn;MsEdz57V-obZs` zaK5XzVqi#qyvnJ>>rJl7-ww*dCELH4_PnF_$Kk#sU){brJwH0$*)8KnqAZD(P11aw Y4`DOdsdJ#l>>AIsn5_%chXY*We}(63v;Y7A literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/cond_true.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/cond_true.fb new file mode 100644 index 0000000000000000000000000000000000000000..003f7868a2d6be3d5dc663dd64b66862b69470c3 GIT binary patch literal 4088 zcmb`KPiS0K6voe_Ni(%k$8;*81erw^S%fx9C`Fn7?n*>l6^U9TcKZA7dw2RWzD7;v!1vC(=iYbj zx#!<|Z^Qyh*(GlJEonoRvNSNufootLjDT@q_8hZGcX!qX8Hvs$+9A0imjUq0IkVlf zW}E$H>ApbMdE|4?n;r6C6a4h7*$!xdc`yp@4w;RCEiHyN|7f>i7YDhtz#iBE+u#n^ z0vkYfHi1l;hfji0kOHs$cz*xF=bx~%St5i^Q7RNG%~JDYFUM5F1Y4+93RfCyrRKu& ztwl|SbljX@Bd$GUuXsuq=nU8aH-YxK_j|jYg-?MAAbT?)4LH7A19KFr@WT$5i%OMR zIloYRGhZ%S8ee6RNtfa!z4AevPQ#>y+-|s- zzWOpU`Od{TkTDAAX>vdBx%Sk4ijD4T+!VqEe?I+Ej z?pSa4z&_fDE!OiZ#p{6k`v>i|YU$^YkK3Hr1jrvVI^RxV(BdoM0?^nmm{d zu3qnd$tU}(_xY3yb)Gc-I)!}T`^Pr=dYMlLJ+QAkpDt6EisMzdV)XFP^`*3AE$}_~ z4yZ4E4Yt4rSO>}%<;gr;J{tuoARl=0F-;;Umnw}K&02FlUpl>Mu2i|vgM6CP}6-FruV_l>d7yP_Ie?o$Dr$%{k7Vh3hQJ`w5^krGffw^jr*M zNj@x;X!ea#wQ?Jug*vB@OP_RSK%7p@i6!;qTE5Z9E!V57OUriH13MMp#&W)&pm&z>N&NZ57wN1**Q}Zy^hDo-8O3q2UR{CN zWok+_aTT6~zYf>@xXu5czIsviw!j|P0oy?D@GYeOBvRbVui}bjg*AF^I-JjNul$iMGp&C@_2K>R>&K5R!zAR# z3&GG`m80&xTCEiIaeCkFch1?9W9;pDggsijADNdu0yGfvU=()Gxcy#Wh1<{ftx4IS|K-32S<$hfPxayV*9z6LwE24L__FiQGHfp% Vqf7rpI*Z=$NwoPo>*$H}#K(F09321v literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/identity_n_2.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/identity_n_2.fb new file mode 100644 index 0000000000000000000000000000000000000000..d850a4483cdc11cffafd20a69cee8ab01fdcbbe6 GIT binary patch literal 1120 zcmbW0KTDfY6vj_tOl*}B!dnQfg?C6HLnsOkD#16Ul_ErkP92IiS}8BY82pDqa46C) 
z9h}52r60k;!O@|ElOH9cqv`Lt$&I$Piw7RgJ@=k_&U4Q>H*a0Z+7lKj8?q70S^&%{ zU;oAz>oh?j{+k@LSYcK!X_vg}z%d?$mUE&d;9)f899Ehv(et>rc zty+t_Y7l||=neAm(hK3uz83em(X7{2R$tZXbvq?ynLWi)9Tj&J+!d4Ue?&QJyYZ$D z0(x^Z8QpFvP6gaIK382YmS1js+Fqj{WNKe`I;zzbS~dEDrfW&nAPdyr98eAZ>_Ltm zYNFE4GhniEthk%dy#wIyzvaGSSJ65zu-02O1JRqs{b>GaHJWPvn5nNHM$^yRGm~fU zH0|Og@!I#+>waHC=Vsns>6$J_+*qPqIeyyRT*aXU#0h%vOXHbxuToz9%wV%PkV|3$8m8C`W#ieem#@Qr#TWoanW2Tx8ezZ G|DoSQbAx{X literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/non2d_0A.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/non2d_0A.fb new file mode 100644 index 0000000000000000000000000000000000000000..2276c43420d862534a070b963f2fd2a3033391b4 GIT binary patch literal 1960 zcmb`IJ!n%=6vuCxv`MvSEh$v(P^3e!MG96a*yg2LI*Al4)j_PUNz~}G4ao}z2Pq=b z(ZQi(hayEBIyj1>i{K<8PEJzP&CzQ7{ohATLWWX%;Q!9Yz4v_Fd)~c$#;0y`3?XI` zrpt7jBrqliW`PU(z#uT@5Mv2X>M(IMQ8ZEus0}R%VjvF6S=+Ho=j-H2qd6|00eCxD z?;sXPf!7#*XfgEZcdZ8Njq)jimtY+{11n$|I3NeE0i}?Er-1V90=ehWBS%*+J&YiC z0>vx4cJ?k@HifFUp?cKHruoezo9YpYD7j#g{Kk+s;mS?=jcq@_QRHDT0tSHgOSx)J zNP(MB-Qfq33X4U@&gL)Jj$>X0KIKbw%WpquhAz+7bZ)$ zD);h>E*@v`H-$|#D?jnw=2yQ5I*&U+lBHj@nsR;s*PN8O==E*fOnJTAM}z^oHN|eI zj!`pa#!beHkJ{^<6~YEs15d$YPz4Ui0i9W^e!6;ka_HyK*PDu~oRo)KaPqL7exi-d zQpGJ+Tv>f;WedkbE}w$(`-1R63ZytyB@^TCNX?7p{ zHTTp;;8VS{*Z)ZOR-ru4Rp1t~^QC-|UQN<7|879Pi;~RCfdAPKonP3yr#Lg+XJ_HE zPg4sg9O*`01LBx!bHVrwT;uvDb>_+=T~0@=%hB)F{c~5W>UK80b#mD{{_Tu4y|QHu zX5Lz3A4et*KdPix-(R0tdhss78ugS&u|@JF3VG0K{)L1+)U)0!|06$hcBx=5sNZ|I zWA$37v-M)cv4}dOs$m9%wVa?H#p(BI?euDu9>k7>IbeePOq=7^za^i2KW(}*!x&Gn zsW$DYf4}#TQ<$ITh55BNtb4;cMQr>(dw31QwT-puUAYYg0ZTcX4ra!@^~zbVr@u1p S-RbN{bG^~s)QhWzH}fyyrR{5Kij-t*V(X%mXotEPNM2^1nZfyE-i!@0 zq$z}$omrI9MHdl@bP*&%7hO~kb>Tu>Xb~3*1wjg;XsNN<==1yTdw25k`l4+nCww{g zo_p`P=boSY-kU*jC?-Q_^hi|tB_?q|qyc1rHDC}J21E{EEUH>ziJ(X74%sesp%({w zfH3funQK(*PTR^ECBtoBd{QLcCvp{&YP}*?x<&Gs-w)iu=IcNS`1NU#4PX(-0JpH= z3~>7?ksCk~Af;qd*&acg_76%zlCU@>BXStBBXU%ZNg91cVQeo2U=6qotN<;5XQn(0 
zOaVOBQE;9e=BRUpz+tjH0o2WWGiOd)&9VsPJU5OaXDDz<_c;!A=jkIg*&Wwi>f_i7 z{Y5XWJLSj{qY8&zS(|fBKPgyu|%E=BLbWwk^^Z(&yv4 z^Dvb0;~J8mJC<(CbIvsNpXbOI%VxzYRx59d+=rc$ka_J)q90_Z-Uq`E5o1;}EGtu} zSDSMMnfiTy^db$VXiZehwURlH_O-{{9^q;JHEwNImp_o0oQ;v zU=>&b=(7@#2Us^q9d%CHLtB_viLZbBG~RdU2AV3DSU;=_BcE4&<{V|{Pd^W?DM+Tr zG^hWWt1OuWIx#mO)V~4HChA}OqjUe$SH()LWaP{Vqg0YxuFk`d@74db&6qE$cn-tY zY4khM$C%B&RG(|21RKzL8udp%o&~3mVt+aDq|aKwP2f6k6}Swn0+cgn8CS;h6gX{4 z05M=~bn*2MU-|xBG~4l{Z_N3cQOQrE{H=CeY0t|_uCtb_ocl9K9#l`1@eH)T0Av9Y z<@8_nPuo5QzE{6_W3YhwXYDS7Gv`1{cP^aVI`G04_oq#? zy<_Sb6!99F>BGJ1_vg`zTh2S%ad7TEtDarGA;xAN zAZOfWG0q&Ba>r>0=SP9D^`+;=e>ilOZTg&~x$FA&{igepRMXn>oYPhRX`@vwo9E2p zT%jRy+}X2>`)hmd!XDNvW6E!*A=CqIq;8JqJ@baJminU3cxZ>+)>vr#O!rUBd2N^9 z?%nL<8vU($N$GaJho(GxU&7veFIe)7(`N)pbEoI*tnJYL;P>+OJI^TIlko0}cU}iV5z~FWKW*bZKT)k%2!(v9bocIv#R6n}-f^COAKwg^_daB_ zf$`WrlgP8`JUexMg}KzX^L=Ij-|4)WsD(FELx8o+qMikq zt8$xT#Q4c)E|9k=>9m6Uz+Lrk65j;aybKxhk@L2ne*>^n+f1w5Ht+D^Cs@l2;Cg=J z2u?QCUgte)Ux?3Q-eJQ4mLV6ynAxNDrQQAbV;=DL7S;@Ng=h4hGwN1o&^h@NiThqT Is^2;O1whM0MF0Q* literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/pad_1D.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/pad_1D.fb new file mode 100644 index 0000000000000000000000000000000000000000..228c4c95d1ed847436139981871eb0498225d2f8 GIT binary patch literal 1336 zcmbW1v1?ON6voeM(vS9D(93e04S!AWMC1>h>kffkQ9i)K~uZGi=>VO@7lfhe6-v8vTz zZ`g)4dEc^4+vfC!i?%c3Nx%{K0A7O^;5mpuAKV6#&|%cRDo_qTp6q`-x3zx(*Xdef zM2aP0-_6CXwQktIb&=8AtTp8|ZCVqb@SVvQ@Hq<<$02?WNV_!ZyrW;KPSth;G(iol zfmNUwZyA-JPR3Js!#Ih;-r!0YMV8BQ7nkxEuWB%hljWJFbHl^ofnMH(Z<}#mz7_V~ z-+A=#@!lXF!}uP)>vQr=*PsYbZxkPl!+4BKj&*A0J2&kn<=p}E>SwFFhrLmMApa?M zraS&g9;u$Ztq_rCEp6*EI;24>>b-)Mowu~0J1qgWo_vC5y64`qP8SA>>HZaaoSpfV z(ZAPTVW;+QpST8M-PfvzFH`I=zDsMD`OdsPeRUIk?%aepukitTl+(RTR}Ei0s%iV$ zooDSEXLs9|l6RM*(W~~>`D115&sERM{QN(z3h%*2H literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/reduce_all_rank2_d0_keep.fb 
b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/reduce_all_rank2_d0_keep.fb new file mode 100644 index 0000000000000000000000000000000000000000..cc92e49f4b803af0e17413acd12399dc65fdf6a8 GIT binary patch literal 1240 zcmb`HJx?1!5QgWBWgB8emSrIeQMjO>prCM(B84NYR1vZeqRu5b;uGf&`GP>HDJWC8 zbde%ON`3@CiUyA7-SgTANF*f2dhcd-c0P9R9XBnPnr)yItzuQHSp%4LKpS*H6Knyq zxA-bfD_I#!vJzx-krjaVpa4D<&890T$hP%d{xTU=qiLJgBI~y8*e?FIy(6xJ?WCeS z3HSqM;1{?6=RoV9!68s-)l=LiQ16QMB3)1$VUjqVww(O|+l&0NeXRNqyx})#ZLaG* z(`z9g-}`p+ zN_khw8>f4t!8nP;fga!eCH?-ZXXPD6XCFtCLHGmx?Y1H=jOSN!-NH*h}@-S=@k43mQ>Jxz3ilC4omyvXTUXVz3#E1JBr)zM&u*OCC(sZ}Xk=?tAZ^d(XMwd+vF!uhw{_+HB@vjj1$MrrOj1V~&9qFb8TuJuv2-T&uK4 zkC}DUR35EJ`$e^PfyoaTGyjw^HSaU#1ypyUpyA#0P-D!2CyZI7;cOLhgE9M1cWjk9xQD$W*qDTOIwXO32MQ$E$9UG zU={RPZ&%*b!+D>QDk5ng7X>RWA8 zg3=r5{qsHEZT1%`n0T^3F))%G9QYC(V4NMC%N%EfdcJYGez=@d9_P#uj(@&)c<{*K zUejl&{`QSQtyMtEWHDXqn{WY?vJH7o<(mIS&koAV&=s}VyGk8 zKf>%T~+#h;%cWkc^xyRJ6rzH^{b z27_h90+P)zw&W%V5l$Flh~`6Q-$xNoGWL3HMZEXq~%Lz2>RrQgYCnIkFdm3XuUg|PQQb5 z3xD)@qU&Se4p;)>#Z_<_On|ds9Ec~~Nb$B7h;zT&@zZZLUHajVBH zZ${!C^FgHYjO|FRoo1xwvlY1s`7xyCd-hGsqoeZm6sY(8tG(|Yb#dwzm;;JC;?Zej zBd7z_fRtsWXe+Itt#pqZO{RK!z1?d!b@ea)72?xpZM#+*V^+f@?|u~frIa?Vt&_4t zev%zJ9>S(LP>~&LIKF`F^y~ZZW0YgXLiwf%L_iJDT2mWWd^N7RB7^a0+}9~x@|*K{ zKAkH{0tf2X#iK>^wNWosUyyTczb}7XXq_J#j`b&=1swnK*Xgu!%A3BNe%p(4y6*(I z3KZMUfpIVfP6Ay^0r7J`QvPyr?)%5Ky!F-hum7(6E@)0|GD$pDT|U|fvUvG%Wt*Jz zOQGxTJ}0HO5{Nqwi&vuy;qzw3*IDX3%Z5ze)W22!JO{b++KqI&ntonK@@Qka|KQED zxUG3z0t;XcOoJ(K1zZ9XK=Z8Wk&Sggx{Aqr#E-1`_BOGsyvK`QU82=WDYe~-0uICt z*(s?ws}js_PTvln7mL15`Al{+0{Kp3sREkwLM4+IZOmC44NBD)njafotF@Pz;y=XTdlSAG(pU%dLBjO-1^IwV*#Sd{}Wtb=k576ym|B{Wjcc zODSz!TS@gTJ6E%>oLLN?2g|-6Ti+NGx_E69?fIEoJT&9pLT*m9P`7^ zO6v1Wv2D2Dm2%oT9Zt$NtsCW*tKsW`KZnoPv%bz!`6XL2`MQk0F!|cC6h2>XlB+4_ zxQ$dkH0tL;;_z9ZAaVyGe$Ingpx8JGt^n~>8<%`m2dY6a`(rcr7vgO0djT5xrQG&T zk0hoV*)92?=f0`GgwMM=d>*CGuz1%Og1%s^+(2%o{rk+amp_R!r@;c42Qy#_Tm_ea z+cU+FUh1XtV;gM?@ndK>5${UIxzj79v~g`EHUF~jVLX`rYxq2<$G1`X4T}e<5cCCk 
zu#sF$aiYQO=eeVb{Gx+3UqdcbXCnL;qJeBr<8Lv$O`D0@eT?cR#r82U4-`YDfnvyI z;J7l*vG~%B%-sK8y0`6XPpp6Oy5>Vz0|o{MhY&0EdX5Y6xgf&nh7Q6e_yMJ+F_4pfAXisXJcYqy1d--G-x- z`{=z-4bc1Mdz@bei$L>!4(V*E=Gf&u9NL^?#@6z4|&tXMCdhdqM ziz~j)Qs+@NWX78+d>rLmIr@U!I8Rd+|qW>i;wT>?$$-*w@uqRl@UL z_P^7t4)-Ho>Hj^oU^Ve)`d;`vt7p!l^j{7yGI^FlpVnrn`hpyr3}Jl1_;ZjPP4VXo zNb$D894IQU(wvQhC9nYIz%-ZwSHLAO0Tibc&*Y0bPz{QCPE&jfAF%NO;N-4G7 ziUJN4_o{pu@i#NY%d4OhXwJK_FNzFvKfL6}cKJ+pGy?f9-yX~9oRZ1K8nVs%U+{R@ zvh+kjsruDnkfRsL<<@cjsedT`RXifUX|6T@*SV(LUG?m_^uBd@&bC==+h}ZRGsm8* z7%I0HpJRD3?tC7-E$k`E4NKGOK@qF#p zR1^D%o?|q=qMsqi8t!^N%39m!{qL1N7Xs1aJ;ZpekyH7lgkHjGmGnc2M0!2A)mK> zW{-aZ=k&bP5^U3+yJmCOhnIMHwq)x)h+f6;)x-j?E56OiGrsejQ*M!ClmGwa55qY; zkTu6|&_=%7vI?8r|E9j^>k*gTT8E*@=0zj#(-`)RHaH&x#Y{NlootM7nb z@nnkcCnLuqd-g<*N8-tNWMD8A=^gBgM-qJreS7uH(d0;?)#`E$y{5=k?Ay1SW7)SI_$6rH#$%BNh1Z?KiCDa~>FLOx zNFXj-?YVYHU$M3{GxC}hTQqKGib{^gJV5yyxwWe8ZgI2g(n|fcT*pqo`7e d)h<@+>uay2i#5#^tMSWm>z7^s6Tc`v{TD$=8=n9G literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/scatter_nd_update.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/scatter_nd_update.fb new file mode 100644 index 0000000000000000000000000000000000000000..609382182986e37462b3398ea0db1a2c8177de09 GIT binary patch literal 2584 zcmeHJ&r4KM6h5P)Gv;KISPddFi=k#=VBx}ry{FbfLZVP(5M_MxBb%ItVP=FOga(n2 zXyGPsr4lY8Tu8Vo2ot0zh&B-gfeRP*FN8LI-+S+t48tHAwCTY2&b{~Cd+xdCeD}`y z#1gMGqu~^{c*HABfJh(E35)OL)QYkE|EcC zsYYbkDUt`;0ZLG=oBO)tin+H>TEs5_jJHWU>>Y@rOSE?*t(MN=FA!jY3eHfrQ@@k}@s3s0tFMlNpK8R7|Gz_MS9 z4Jj)IA(d>_4C|}*szx`6P7C1WAcSM90Q1cK#`cQEZs-o64PXxZKoj5r?6jd}oxyCB zeCp9pq|!+v8oz8Lld@*}q%YQyewmYUJ~3QYxUZ2Dpzhm$d>B)>w-i@bn z@Y?XBSmYk3p~;_CXx1kP7Lk_+*kfp#Qf^Gyl%`Lv=h@oO@pAye`mY1O>u<;PSL+X- zUD)HW^9mHwW{kK+!xGqaGxze?e`pWnTr)@+NaQFKMm+39Ug7_ zMMxW++|s_yg|wla9c}H*svh1;YB$D0+D2eazn|~bJDo92@158CGEMsN;al3r#aVsg zi&t-)xUFr!4`>TjL2cn{{?e-fubz!gYu_LC>)FR+dgNR|uRmYV<=lq$=Cj9*J;@~s zW1!%PHAQKgV%7r!i|R-1c~o4RNQE!59l^a;=bpu6b?#@#2Wzox=Oq9e*ZH9J!he*@ 
z!T$P?+Z0t7rByw|Z<;IBr(M8apZOA!@%)ume6BNjN-nB;4dZT*8-wPZsdKcT&3e%e z;kE>)IQPe?c%?>zrk p@Q#y*2j$tJ?^mG$cPreza`paKWBcoOSf0_xO(T;qBFQ*2{|hdYh9m$0 literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/simpleif_0_alt.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/simpleif_0_alt.fb new file mode 100644 index 0000000000000000000000000000000000000000..4a7e751c3f7eb16ddc5de7d693e363708a0dfd39 GIT binary patch literal 7648 zcmchce`u9e7{||a&RMILZMAC15Q$-eSKA*Y78hKwKTJkRNQ~Td_nMdPw#)8atq58% zV#Ek335oq7Au&W@NPk%P#~}X@kr>t=68<3}Vr3=$w4ZOE?|I*IcQ)Tmb8kKHz0dji zoaa2xbDs0Q?`<}YR+?25t4*z`GxesyG?|su_k(V50W^aYXa&aH&X`(Tsxix{iI!Fc z5!os~OF#pt2FpRduQ*uVxxp)^SkIWVcN#wZjbifeOyK9`KUfcUYfL-y_hlHI{x((pF04oJrqAlu328skVZ z;s_VWXN62pZcC<6FyoO-$x<9huJ#=#ljc~`x|yMTNu3{G2!cD{XTG;~e^KREz;W;) zI0*KGF)$2d`vQ>Nx5E8B?V!FHIMSYEd!aWs@LUeAHm!dP=;`mvu74uxpe~xPJ?x$R zD_J@tvUwAgIyFFqsg0X1xYnBfI(a{=`6iT^lpo5-lkIL+9>>(|9?10OwgJj-UkrlV z;P1he8#a9#o-37y!4y!;y0f*9@+cSvWg!1{!lh3$s0ZEPp7w`7I@EaNh)gZN)`HUD zuJS;pN7E|)t)p!HyE{`T<#zX}&(T7CZj9qv=T0CWp8(QX5}RssiXXJ5^+S1J(09Xp?F^qql3NdC?~@?T{wem7rmqdG11Zn~8i4jHJ2;ZO zIGUU1Mk1M#r9Dfo_8lja=2+6Y&i-485#<)Gzk{;Qi^`{d3iF=yJP59UOW-WfevgAm zFagGa{2+h%dmN>%yW5d$wT#1|d!+PYzTC4XwBH))(y0;nXJh?ScF?8`wCg^J>7zBI z>-D>FK9@c^_sZ=m7ua}r41&j6PLF8>l&dti3n=gM# zKwfbmm(BNN%F4ehiLqMveQ=q2j2(JCL0LA@Y z@NeQwk5n`NGG=DTbl#nq9{#<74LWMebx4JC^M;e8$n)@`R(%r^_ z9~cG0pbT`6cEV+^W>62hf&YC#awRYB`+)L+jT0%gGoPhep@sS?gS>*)(Ki_J52JR) zQ5UENL*(oex*O7S*?^X47KUaMP}{mTof- z^Cci{$Lgn@8I~F~U(Yp@U;^w1V_*cRy$EDv`K<#kzqNn{&9!SR}g3~ z@>fUXFP(sY;jbFaF97-K0N4vgz$?J{YCC25OZvA0`KlJQ1KCsegYIWe&|pt)XKAoE zY?nXuzWS6JB64-jL*%P!)?Tb2@rayU^_{13S`Cn7?xC!2L6cE^n!9NqW#3=T)TO7+ zu;%zAQR!U7b91T=-$hhw)FQF3pt(D*1waHt5|tH zmn*jCd+iw}e1p&@HN;J|J&)mjH_I-!=MaU`nvapWJ%|^ z`$G1AIx~>Z>?-8e59Bgg>W&ubb7Oq2J~pXPaI zwhyJ>d2`)fH<6n(cfRjBtC8*O0rJH&dLJCC0B{r_EWo?CN_ zHDzb5xzKyw8E>vza~1p6T*>lmIqMzp7UXN68l$+P%5z_xjYFL^B-xkU6 zW9dEC*4+7?_g}Yjs%FLCr_9lsm00nf{f@?r|C{l8u2FouXO1soba?>VD5l*VC$01! 
zUY(6}#s6}vr}MY7pZutt6PG_i{Z~|F>E8pSH%kXayZ(wQ*-Ltvm~No;rH}s840pQR+7&6Y?$I>kB)dW_ zVJS|JKIF-U+LOz*8hy&{dD+q?%8DUp3-yUUpkKR_uiXEj@`2lf?CtGgX^MIqaB?)) S*~@X7pc=-^+*331Q|2ECAxo10 literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/simplewhile_1.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/simplewhile_1.fb new file mode 100644 index 0000000000000000000000000000000000000000..c4fa26e2ab10bf0fc1ba5c4eef5ebb57e3c281d0 GIT binary patch literal 12504 zcmb`Oe`sB28OPtWS(|oO7rS(eU66|wtXXWcnLlc&c-^v*l{m^;q?EEIORnb5n;#^% zOIE^mg&;x7D4IWlgb~IVWpt1~!ib|t(Ed?Mk*<9 zoSSsbdElG(yzl#*=Y8Mj`My8yIVVlV$rke&E^AGlNtt@nU>eO9@(bVy*j{5y6KRXC zfia&%tIoUHU^bGItgk8gqTD*r0BXTTkQ*2}Ryg`l=(^QxrF`$_jX8zUrq3C3Ep5zM zIx<9Kmp^SxV|5v$bb~M*lf%!$b&Xuz{M|62F5@J z)PkiQ)FJQux>O?Bx{W_$U=ExDGvFkc04g5>y+H2jAZ-Qeb3ND(zH<5Q3-|uuJNGag z!_YiW=X$dPgjp4*tbk*=@FlLy)rC{pcmOC>oF$6c5g^<7rBW$F+6His z*T*!eVwwWdaUvUWayNyT7veY88DH;cYajaM#{U5z`=CdIL$cDu)NBdT0?sZ_e3 zbkDCmzmGy0&(D`hFMvfb4^D$wFbz%s`MryDH%J4uueiwfPBbR@UfdqYjunpP3fbZA zLT+#%@W1#w+Ly~`cRh*56HWn^2XZ?)fwsL;+v1SYl>Q*l+vAgqPwJiIA(`@J2Xd~B(r(%5;EC9v&RWJ=sf(f8_%Wu0u z8YtdFU_WpoH?J0NmZ7+KI}>KL`m&;bp*F=?rPbnm3vOqqZv!nrzDxs+Aq6xaoXAd` zMCW106>C$rl9-3n*h@0b@BF+}(!6Q@y zGw&uFw?}?iDv3w?NZDS^53gB&YmUjk8rK|{1ykSzD1tGdc*x(Kq~ekD_aUzDZ{Pi+ z*SEg$hVyqFWh2L)DGYb_6u2z=UGv+np9hXTqY7WDmAiVE%D%XKn_z3PJwBl7)$`gU`46 z6(awOD=Ybb{%Z36ulsGX^I#wQl-9_nNOkW0o?UlFH8+aj444Os!3>xJlb{F`ANgPY zZUt$e`?|nk;Ka>^S_-mbIo(#C2Pv-YJa~+nisOT~##gJ8jj*TGnQ+3`Qxd8Iiv0nUL%Fb`a8 zPjao;jswN@AgTJH@yJFWxZU%)^lj{(@K5iRv0bg6+n0%AE59l|cvqhXET_TgdnZ~gY%XgVp?@6RTBy&X2S zcq3@z?6%ir{LZ$U4)d_R-m3XJ1FnF};2e-&=D@2!bti#hnJ4W8yMcV8_O*_Nq_o=lIq`?&{lDbfyi@?eTtE2XpFxUyh8~K^lmwDWH3t_-0(ToJQBzCEuoO-O2ou;fna*NH+ds3D_GN&i3Ydx(iwQ`W9=1 z;`9?z&CQSAx4aU^#GR{P39QDeu4-v*Lehq`>5N%3V4dr9);Ci|rP()?F1u3L3I4i>;1m<5XI2~Y%D z6Y@ZD?IhK>(x3qx0$08~_oc>j52&gfTg5S-9T{-=6KLXUol|LsR&(^=J&Hw6``%N~$9wk4#d9zeXks40m9sZeJGi9|iyD{zV%)UYQ 
z7eKYRsje3&ZtbMn6BV;Ma1h9!92>m2u&o^H$sQd!)^E$zXU$K49anK#qUEwf6E8|Z zvG5El^+omCt5tWI%sCYxYNL;9%}w!6`58DlMcM{hfVieLME%xyr4tsJQe9N$*#{Uw zO+6L5M|hD=>Wt(!zby58UoP)K*NI}*o9oXGjPUB>I8h*e8Op-_Op}l6r)!Vu3emjF zjXc&rRCxX=&;8_Qu-i_$7Q1RhG0KfRI@mvy&uaeON}#{u&9}4Jq4r$AH?Q}w@5rx@ zkgC1of4BRDVqFC1z#^Cjr@#!D1VvB)grmGqq{+*7U7!~@k?+J&fj4*AVN#W;U7ZOm zKY8x9>oLEULS&^n)ysz7+eJm?inY=){Xx`LBiE-%-8$2LYwhumo?oZ2xtBJ?`5i#~ zUy1uvSBU!Q)|qSAI!L|^xOtT3x)Fruk%p+UIFX+9ul*|>#X0;;sFRpyD2ukESccmn z>-p4osj_8}zs1#okwW)C!I=^E7ft^GLcRBN^73DQ|Y zY2bRu&y#rKxFDZeE_l53I8u!_fv(3J-%|MbS7muwA#|a3#d-0iiA=o;ko8~=907{S znGciaOFePOU^^+kbR}Rf%9+>LzqC$$@Nc_^i!=Lu&UAddbJJG&xd;})0=S*^qzS!x z&;`QJImHRti+i4_HRBuI!@2Hf__Sd-+ucjv$x3xjC!C57nk!1xXNkDtcyalo3xb3(9oduN0y>-zFKn)du%?VPZ| zub=SeYAyNbIU)a_apv@K_eri2S0rg`gQN24Sq=>~q@@+tN6zeq5TnpRM z5LFf@(p&AEP@tW5^rG#=_1Sl4TvOiJEPPhbQ_szeSh3K2(7NE>fqcw$VN^E8!6JyR z3q{JtKpu#XoutlJ8p9#b3!G>SZoi10TZ8?g^2}PTTsCE|8V}Wu>QDNPQ+!*wK27S@ zhg<(ko`)UyAt@eCVM{)Y@}rS5@k?WpXo%u^HIeRWTx{TtSoj-Fy(3c$Ga%Z&%ERrG z_4=SRD$#T7VfI1w_Xg`IY*;5d+C#*j=cv=SY2uvw_HYaHI;!_FdO9nu^!?kunEQUg zee#{&fmwf4d|PK;h`DDodx*XR(HgBU8D5Fe*@=$6RnY!y?b>fneioy97+v){(6!%3 zd=#VmU38tVt=-@qezUh8>#sep zzY?Rn6J4G6;?8fEV)VsBopXY|xO2u9)*MAsD}Xq&(R*ha9sA4NFO8sYcF!o-wrcX! 
z4FR14DDs}ih!a{bgSZ7ezhGlgBhE=j_-Zj&U&$Ji_WqXFxgMJ}&D^WKCg=-upZEF= zu+xo1PcS!fQV?f}ZcUND6JR@t_Ajtw{oUt{X;Y2vRlnqCUuxt7jZ0$p`5^B~T)peB zbl;`VvK6#tf6wU+bz`L0S{v|xjrPng;O4fz<+tP8=*=DJg#DL6S8bFfZLi|`Gr$utOv$?1g{|aW^$49xx{?SU>(11@iY2lN4jb%)bw$bP2S41=>^4$OdK zUcXKzPOTW{x{-NvYlsW!KO{X09m zI_~j5ABXQz%Z3fvP$-|}w^Z$@kD}%JAS?xwfT{EP@cwu@og`idWcRfo0W_v^pnlpA zh_J~8iW$l7yV`biJk-|JWtx87mnm7*Kys@)fltQRLM_uUFnp)P1I> zZCA(RfOyX@r_-9abF`S=S&D() zb7|a|LU?SOma1)BMP7Zsjjbo)(U_$d&1nKuf^u*@`cD2vy!swOUYz>g{wwLU#<}uk zU*Go;mp)H{6W|D#0AoP9mM!!H>AM9Q)^|7QCSXG$`u-Tq5i8t^FpHHhFa{B|c|Q=% zwXWx|?Ih)K>bd%*c=gVd7Z~A&>tNcLy`6p;Er$496hC=x)zopW* z_NaV%KZT7a;fYi43FcEWwT9yG=iSJQQ}3I8EuH=heOvOXr}x$J?efc7*OOor3sP5|{+L23CNwKXyszdq*!nQ`g!_>1YZ_N$Z7PeWh$9bXTUe*hGsher6sAJ)Uo z4G(EW@lbXt4K6VMvydsDt2luzsjvaw@e5{!WnFa!pG^rZFB3{`&;pc3o^9HAG2 z)rD-cqpved4;AEv{olhuxvjidx#hD|ZOMm==Gy0u{Z{;XsbJocw5fTJZ_r%IE;S$J zKx1Dlh{v8^Tai~!`g-NL=huzjPN(li-le1dI+O2{4rjp(I0hzxbbJsDfqo$U%D<@} z8czl220t?Qp;I6Gg`fYZ)omGho$VdFdpmodCAzTAH2$!D9}V${PoWSVpELVYZOd-rN3VGtDjmACqQ_ujF27z z(%At}tUu&&OO{^b55qcHteomuLGveHE}F|9o;ni0J`&%F*B^ExuNeKh{%{lbMd|qB z@A~_l*5MeK2N%F;FaxZv$B9eV(r+&`Y_rn+ZNLWkpk*8ac|2zJ?(gi~vC}^%tS4Ox zWar_v_oT=3p!e{`)vsShg&-_eZu!Dg{gGes&b!bqjG;AHSn5wH@kwZl;@~VcpM=J_ zK5Ln0$<(}61Ffej5VhvXjtuIX-5*oPt0#TE@?3kEjX_>_Pe*&_j<((o`t>6B)K5VV zL)Sx}_^xkfvN_q(6)+FZfmv_@908MH49L#=6d8JU- zo*LnK$Y-I{4~6hpUQ5-c-HUVWUxU9Nzx|vGWNJR*u%FuL;`JSV9=(Knp!#+EkNk6^ z;?nahm;v%7*P`cd{L(F7In@5-vsTY5so&MHcki~2pV`;e%me*3XD_fPnI04ju1*Kbn+V1^>`1L#!$c#(Rhqznv)tY$`dM9-2-+202uDk<4 zx|s%(U<{0aA&dg3e(7`Q18D~nFQ~;RI>)O8tVA2v)&}dN@y^fz2~n)_sg$LgEAQt zn1`GPr@<+30?5{7(?d|%w`{oyRDe-X4-{l)37|FP?q!@a_sWfAz7zF#LNA|PP|a&? 
z2)jCZzS;p*UX-ja2jvud!EMVew|rqL*>?ir5K%f5$f8UWWl>SG-K`;Og0O@rnXa&tcwxa%MY*j#Z zl2m+=@86cf6aLQ4O;el|q{H<^`%BduD>X zjnK8wGp~8`y?-6?eqh((6!CE|0)~K{f6ce%UHwUbN+5km2R3M~w6+9yeYtDWTR)+b z>0;%!jWAVvvKQ$h-})N+v-tIMC6F1Hem0{I&E@sVbJy3}>2zA(_pgKA4m~;R>-!Gk z()TQw0mr~37z0CKF@3kgV?$WqeB;y6LxJ*2f$J|H4$5ua#mdDeS$ZeQ4a%3@iVo<5 zpf;bbm%iVV59|lCU$UAw5@rm z23iOC)&Rv>p_G*Mdv)kU^Xb~tjrMyJw401*_gd`O^?T2ASJXUgI_uflYWcukAln)T zuHWk>Ph*#zH9$3&s$T(`6kp{0D%#-lId401s^d5I!eFwVaPts|<<9^tGzPev)J_U|}DIgt;fMGBI zQb79A{A<3|p9H7`t)S5NmE`yR`F6jV_xFTcWmL-w(uZtDG}ryG{dD~LI1tz3Y1D4S!u#5$^{xU>byVJ49N#JpiQFW~ln1Iu&3iDD-)x_#gBT{Q5nJ%wqH*KVa9|74~WALjB+H557LM=2Bo9 z$cIdTQ7{a0tuw2SR(J~4$4B7Jo^LXOkgv@0gsJ+kKIGEFslSb14+-{#xb%=>|I!|* z{vY~#|6HTH`~e`l3hQ8ybPDu=E}(VM2-UhsfJ(3h6zcD^?`3=%i6G?4x8;_q&$1cO zTz;;RXW&W7>p%?%_qCj~Y|;kFi3;w%GaAU0tZE>+vH^|Xl}T1)2(N{CCzlv*VH9>I2$K6y#r`-aSf?ukncA*Jt=` z5dUw^vh!gKLT0_$I<_R5st+2EY*g_sl-3$3PJM9SH)?$22IxPK8Jqr*7pMM{=Zcq? z_5KHUwVO5PQ0e&io4#GGAuiiF2Tp^<>`L>M09q$uyCRv>wzRIKSJzJTo{`pC*se$} zgnVTSOreNuNn_NQ3bCs(_N*klar%`jfy}tH@6KgMq_)%f=$ZN_CUeaMT`ULNDlr+K;bmZ#U%iXQ++z$6$2(y8{I6p&Bo z0@7IsZB8evM7*=rouGc8n9m4}u}k52Odp zdsq+Mq}zcFnyXyTt+bAP|LlMAbm1QN*;FAsYD?{WP7x<4Ut<#;)CWQOai7~a|8xBM z8b@X^>qY&v_o2(k(!ADz8er#q4e@dyUIh)Y%8NyRh(A2%;katJXT>gTqJez3UCn8@ zUDZ|F!4xZ*6|}a(&+gy7=>P4PzT@8jz1VmB`j%hc;S!dq553?Bm;|F>7z~0G zP?_dTcB(N{1Jx-5%|LyXZRUF3pYaFTPyO9|$)Sn7&;9j`R{mkJ=l*Be<1~gkVD&70 zs^9K>ljV&;Wo{Vuo(`ZxvMMlcWFUpHHOg55j-C@47DdPujPk5FYVLNB7%GA*y^m^Azn1 z6j@4K-+fO&?S41>cJcG;-lH!dGcNyGhi;nbuWPSS{!@8TLF;cZe>p-sN%-7$a;=MC zO0$}CePz!Z>9o$9Yv+CYl0H&k226t~Fb+n5>@Ed*f&5+*)b17PLn~+pHmLnv`*Y^q zysIrMUzub}ZZ3P$nM39I?CJbF*EWxcyDGT$^uWKzXHWgeYNo|HVEuH0c-;2%Hti&l z>$W3Zy8D6VRB(f9Psh1O%J+Wll5bDaM?aVavZ-T0`{@`M045SY}vOf_YE&&@|9WH zVXAiI+jLgT_kLIZe_q@Ch1;-k*N#pHGHYR|2J$OaAWA1>M}|^z?VX_-JrrY??p@$L z?w_*DSKdvh)pwoUw_NA@xtimH;53*8(_jjWgArhDL_SA8rV-e&Xg<1wG{g(io5rqw z`L^J{nVZjUGQ6JMcyjcSD}R9@6p?)G&5PN~)PKcqFJIldnIbM0jV zcgOqa-$tnRpiA$2eqpudD+Q*(5m1Oy=={DneT4^{sYzaF}gS=o&Qazm6q>p3Vi40YkTgoTH{ZFX>cR; 
zCi}4KCtGX&TF6yKwX1^kCBK?${mi`=zuu~u>m+T)wSL~?KDrcnFaDo@zR`G6Ulaew z^?@!Rdu)V;b-9Ifq4tX67L-I~+!q=Fsim8uG&Sp&?d*TtU2x*@W%3_LU!QpJYz^ zbZzq@cTx5CRnCQYrzMh80c1bRxj%lE_+r2HdnuysChDr6nP-z%|E}x?&OUCN&rwh3 zNj+cn=P0Y5c{-xb%hb{L1$rkWJg?a`dABK|?nUbI8bRdmC!UPpU(5YXYfAR)x8?u7 z-@6gKkHC8me45X2zwB?&ZMYIKe(HP%e(5w^$G$a4w%Hz0?-lBOnzE1QsHe9L4@K0w zNWB{B=y_wfJ^S|QYoR*olk8%WJ@;zdr@JZTJ-ROS$?_j>-)ho%e{UzMxJ{3_f{FcDQhf-TSi|{vC^ToP*qX>7kzf=3LL|!g}EEsJ%`;6;bEg)VT|O?JM?g-G45E?=|@D zgh%%wN49x6g8v8buZP#|i@ia=8Bu2=KH!to(b-9FkT1tZ+%Y}{&pOI8YuoRQeWLSV zM4cC?b4Ns9?5?ysqRw}zlc0|F6#K2-FGlda1>f!Pxc#xGtdkLS*0Rp^E?j2)YqZ9^ z8^QYsyth)8@%MIB8FQC+PtNwYKMTM1jf}srtKPTMtr2w&Q%C+;_icYRZ5{r0DuVwE z{CaOr`+(2yd!tuFyms9kWUbrxb2EJkuQ+4g3DvV}?hE(~)ywqB*5N%u?_9lXiQ3Ru zZUvcYdeQyowhWDt$v}N~Se>~QcFn*6edtIM_U+)`afBEq@n`Nbc4f0Kp=Z}0JOjV_lQ~oR9%fSn z@3-KU&ctu;r1p7FPXzyQ_|=ar{@~feixK>n;Fn#gj+Jlk!>>p1uVcNd-JiG*{|>icgJU zJ#gncV;}z77zlI}=RChd-(^FYw)|e-=FRW2k|xTUfpnq%l?VA+>#_;qiz%NH^+UoZ zuH7m3X~rPCll|MZwv7Y)H*>UiAfmk_b=B@dNcgd+_Fkhs@n*Hhexg(88b>Sgk*tc6 z7Cu{n=9%X44f?M;TGcDO{aifZ@qZ~%cz>Qop3X^{Pdk31=BPi;rPah`%XYrCrp2#d wsp9eme!K417W$=Usy^pVL3eh^9s#pMt-gNe>n6!$^ZZW literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/tensor_array_close_sz1_float32_nodynamic_noname_noshape.fb new file mode 100644 index 0000000000000000000000000000000000000000..e2a89af8632baa8aab9006541fd97662fcd229ec GIT binary patch literal 2920 zcmb_eO-NKx6h33>XpLc}Hk5>*!2|;u&B8>Sk%S~nB*cQOcpUxc2w@TcssJAaXpH3?&;xMKgYUW?K_3Ew0OMN?lmkA1d#q-vu%?-Q7e=wf za9rz)9@pY=`C{s%t^h#2jHR1SEW0+^)kb3@6t@1*^(OS&)d#9bZ%gvh)hp3N3L0O+ zKkDasK77@6#^f{lLiDTXDGNxa;;&a)6^H#eLABIkJ*fN$fgw#2F?lQ%;`vuzbetA9LYhGO950k z*1N6pg>fb?mVr0GBCru(`XOUX{D2QQKQJ)yd@}us_B|k833?k>6i+DQoueJfZCq(Q zf;sQ3Z4EP{ZM04rI5|Tb$A_lEjfYm78=jtJ zevK$(d+K1b`piA8R?C_CVC9K6Vu_3O{I}#+)sg4oKSJwVfAb@OGtlm!I*TEv9)iv4 
zjhTC%DBB0va%H@anaq^oA1ayb^ryhXm`nk+09-)+3b8X^?c{m+kJk9otLhf=>#^!A axfZ^?9&T}T#&xqeQb)IzjA^}b#qmFDfx8C) literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/tensor_array_split_sz1_float32_nodynamic_noname_noshape.fb new file mode 100644 index 0000000000000000000000000000000000000000..4e3cf40ef3310fbb80b21c6989bfd1023a6b5b63 GIT binary patch literal 3384 zcmb`KUuaxK6vk(hZ8ojj64z}F+J`_Z>_Z?`7imjKwipnlLeU@v>z~_XqYLXM?5-s# z)u&L32I+&Ki1;9V5XnQ3B9)fX2Oml;QbZ(FpNgP}h@fKd-}U#+o!Px!H?&PM@a?%X zXXebAbH01#Zoi3IpSg#JwCOcj(`WL)n1i4Qs-Pe22F7ed*K418OopCXcPeyEy)?*Y zjCq@fg)U?E(BBMBWsUj%4r303b6elE)BiCvVI+&ve z>My+fQn}W|;-}bC?sT5_e+xs&50TpQBcyZ%bI9^GW5z6Z`#He?Y%Opa`B8)!!=_}< zM|Znoz5v9hYv2M{0v~`kfpqggyi)Ary!zznw?5mk{Gyh%cy%YwF<#Zn!K7$q_3~5t z7>%AlA7jzY^__C5F`Zo;@x1AC6O+6-^~c8Z<{M{I@!bgC4E`DUxE^~+?tJlk7^;R` zM^c8&4WweW95sZk^}SDE3LyGOU+#%%B{T@#)4)rJ~vY~5AY5s?geg2-J7>sf0Uwm5a-=l z`km^}xV7Kr-Q3pmkNtA|l^7x+Ji+eR#@xkW$kVB-olbjNqTcXRMgkW(hwh>RoELn?x zh~jKxHE-8Pu^M4q@mO~H-8`$WcWm^LXI~wC{N973`)1x7sZO6B-GARtBYV#sR&1lF zwWV^Y=4I=B7_h@OhkEky)gG1Fb9*#)(#9gn<0O3_)9d&od(D_$>jYYtBH{N*O7EE; z_xoz0+ncK_$2omW_3y#QH1g$n=6X&};e*bse;Lt=jI9CuL$EQ8`Q`m%w%Y3ax*Khb pE7AAWP^C6mo<);0J6gWQ^S#BsE6NWp!NVSAr7bIT)|=JFe*rl`_{IPL literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb b/cavis-native/cavis-native-lib/src/test/tests_cpu/resources/tensor_array_stack_sz3-1_int32_dynamic_name_shape.fb new file mode 100644 index 0000000000000000000000000000000000000000..4a5f83029b5ea710d03019ad7d8a78ff44972605 GIT binary patch literal 6280 zcmc(jU1(HC6vt;{%!gl1HELP{4H{vKfvC+xkx*7DtrQ;uhE|HLuDY?I>nfYA)l!O- zQ2G!`ksu;Lf}};L5256t4?%noQ4sM#S}8shDMiG#*0$DKo9*vE_fDL>*_gPy(1HJ+ zd*|cKnKS=$X7)Cjpfs3gd8jsZW`Suijlh^~;6*S5nm`;FvxwF@`&46Uc~hAa%ZJ6j 
z@>2^=(KtBYn67$b9@ZH1O|>xtRlL*oG#Fn_{W4=VgYl)tG=qUgV;aD@CC01)cNQD7 z1GIu+25kj#a2*2&WhCY1Z}~i8sa|EPi5YVr@dLODE`Tv`8jJv$nnr4E%|PQ8Chm@% z|6=6{9yQ#|1x#sjsZ8I#Y%-fo9(q69n@b&dA!uhQW+mvh?HknohT7Ys6|XZOEKk~B zb`&eBm(8-FgO@mH1`E_cnFB_F_D_A9@J|BS0ouSC&;oEXyB+jF*b=I6ualB(G`*Sq z>Ey1|%gJ=wq(hz3C07 zKCe2h=jg3`UOqX2gofrzq}p2k2zc`p&%;2RxdX0)D?q#%2aY@9&3@i}-b_5#Hh=u9 ziS1O|xuCjP0zBUAN#@>3Wg(7i(VAOR)+4Otm`7#vhEOmdi5T7r%E1)mRnKQTZdAZ~~QgLw^sXm+T1w2vK8wPdi1a2Z6 zS7zc#Dfwj;`aPc5{IY$m&lFo#K$KF(gv(*J!lHWRarxcPGt=?r@Sm0E&DBt6rFhfI zUM5&iIr^f!x%W#x-+;~uB*&YnA@|5IXnSP@dV1 zl#NY5TzspiXW-0d7j8&@Y%b>&(#tVq7h}}c3^cdeJ?^Yu%e!*8SI1jWrmHW;3(cW3 zUwKaLY5h<%XEo1S&$jUW^S>(3haW?o@=YAH0j;?a$WO9krlLOX-08Sq*>k5KeF^lJ zW8b6m!`>@4CUeQ|L;F+aZO%x=tt2u z^3jVMo}t?Aqqxbn%I8#H?+5m}GkI>Upxxuolyhqf1N>z4r2SD}P(&^9*z%f}>GIC~ zzbntP&GKxbO5QgxY=y)@3m&5y~8&V87Ig{ceQktst6@S{_!OC%I5(rFe1< zeVweKRDDs7H2zyYkF~eEQ|Vk%5`E>^r&!MIBR7?)NBru^?%Rtn2RPG}Cq6q-MSL)cQcI9y5#q0QN4a&A9kU(~T zHXx3+fMRE_(@k?hiT;lqLuWaBl3ngBZzR_}`hpQ;KiO7}{%Abi;C!z`=Pjh-;x1Ba zdVo~l8hA}MtWs`m0Lq=qIG^;q5_sq5wBHL}vpVVvbi9cU{eNpM)G_4@_WwH`v3gSJ1K7noc>^@`6cUc2^9pu|EMyoF44 z{Mi-OHX}XX1?#qFNxMr;=Hm(ZzR+(+m{WW4xO3*Xy`N9D`!ZwYr%%0h&+e+(*uA>Y z?%+RMv3PCo&FoGcG_SxH@khMXn#CXGF!5Tv)85H9@~7tyd%xew*}?j!`2Q&OD<5`w z?cO@vKF~uY$P`cD|Gr!@o72}aeHUC{PhjnRt^Ix$jfEB1pmnSQp8b#gZ@WH-2fMCr z#yMVP(53ru+-vvNH7i~*-0JgeY1hS9^eLnxhyt-!$d+px5v-0QjjQ8f8E*FHX zFy^Iz+tvsDjCDT9VuSKfTiD*pvux2ljUxAG>s7R+Lbd$w&I2cv>s6d|`|ogvSj29bs6BQ? 
sw5+Ry-BKt=u*ZBlSWNyb9^1-O@*yUZdur%aZ8L!QQZ&YkjMMg zwZ=4okyXa@08#xf`8;e5%Xm6QkPUA0ejQu|V_+1F0I3*+*`J`vs4of{K~Ex`J($^?$;1vnpU&oD{rg+};k#zXGt5#WpZ{e% z?zt~XmoTYKK9G(x<;L6uQ{WO1JphJyZUmBK95d zrOG2E&F;;ay)Pd~q;p8jA#V$9vRV3xFIK;>540z;1DRxc|7Ns(E|%R-ej~rjms^Z{ z`8Az*byjf8tZ@Q3#7?!?|?+H1dF9+Txa7Y8SL}AfmDg>*dv&pr z=QE(0!Hu92%F|DE}8f_pXZn`=in>h=i%ym`?i-~YC8s$ zLl?kVFa}P6<6s!5U-{JoSKV{@wX*ZO8|xn*{a))xrgAJ3OYcp<)g~J?HjpXqymiCq z^=Pr<+<3>ej%-le>oBwu%%~C4{biKrogjQZp8BKoc`@zFELC1~5u;M(#c||y;*(n$_o)fE`kPI+y%$%QE_*kw z=Vs7onrLsa{%EIic9>3^-wa=)Gg`TS348-S2IJsE5Z)K6H5)+#r~x0n`0m$hFAZp# z-3QvAy6)!g*4|7a7Pp0$)U^9V{egY5js9Rcwe#=QFF9dSo9xiuaCK65Yf-pPj^8Q$ zeu(~6e7-4mu1@xpKwh|3UYyS7b;e$XE4Qz~Rm;Zz_Ufaa74J!)`uG`GC=cTFwE@+# zYVgvYJ;R^A_2o6?c?99AQ-K_?@mHVnMs}({5aUhjc~|@bZM=E3P`kn#&IPR@yJe^P zQvLv~8Lo?GXG)(3)%dcLwd5O}A1y#O*8%xSdfbmR#^pgT@;Z25s=ROxjFmvY)#&5J z)_$+-isfQGhX)g8180KnE6sds*7x$o)*r=e5R8FQa1soIA&>%b&qNv`f<<8bMw_ex*CcAi^+ zV%7)>#oiY~?_5%@2fv5*`ZAXqe@T~aE}zCqATOL#b(}4Ce~Ix<8`-=C$QO#2_+s}# z*f)yRKHq>7oS(Yulu^UEiIr9JTmvvGKS7_)!A20MBkmqH?{|+|o{Wg*_W&p74svv^ zPQzKl_REI?6!E$DW|X~Neu($$!>32#1@y4@!B+ob@l^pkj?kw1s2pEtzLuEajBxt( z@$U3L$9(zaOu)W?{`Riud>rO@mYMVId5pFKy3Fw@zbx_lk}>aDy{hKKW-D#V?S;<% zi*`(UeIffb zW`THYABa)>pc7WFWo62J<+^fGbx?bwl6Tpnbqbw*7yY@xd_GE><`w(>^0GbOy*%zl zp7Yl=0a&m}967p>Xk5$lVs|0qUPcJGV?*4%+E@_)D0sVdUJvtrZ+d~(|> z92yrQ$E{DjY;{rd?s}3Fp6qt7o9ybe+GU$!ScNUBK~ZbdJU_X* zzz9*{{libkK|gtS-qvI~o;YNV1lA4Kp`)xP{a!0#U1@aBI>esMEj|6fMS}aw`P=b@{G&4>kVge;wpU+vaaF8EKc}noa=Ol~ zIp@(&(NOHX(p1a825#P58@0FPh9$+*KpYP6& zcJI6=Dbl7t_}uKy{N^`*e!rPn?W)^2Ich%3Wt-_R3Daqkrpp}Vehvm;3%a2fo&aM$ zjaG+ub=b6XCwZWy=8JdBPbaitce2s88a@2sB{)0bfjuEPeb!U`;bd|d(gd5kpd z=UMKD!HMQ*H$N+OBm6x2lruuQ7oCv0cd51&BTAe3`^PqNH)% zW__Eo700?g{p(sybMrK*&VrxZvg=LrScY}bI+G6;VIE$EG894culWv-ZH)U1;Kcbr zyP~4dk;IGc956P z<|d2z;wxNk`&IhZchCE!^}gM$T(%lTZK|G9#ib;?-pY(IOg=J9LWsolmq(jgTz0e8TN5lS9U$x(C?~5m1@%pH4W-rgILl0ssH!>V)xNcaa{t9 zX9bpE9x6}<`Je#u(GaQnk%AVJ5R&0+5#{%m{>y}*7fpPu&pnc#W>Zo8p0Bhf^_iyv0_QG-1%D$&LKSIi@7<23IHm;OchIP;xbQKmsXV9xq 
zh9YRc8Yb0zDNZ_J44wmtwgsX6v2gL?)adxd$*iAOA5%k=K4S|hmjv;4f0@+?eq9?% zo1H@wj4e%hA83D7pHq+o#j5(_M0TP?u^E0JvhLfItsanF%|nz;vLi+9x@Wxv>w19u z{n~3hhu&rHRo)AC?Q^`YOQ5r-3QI5#6)1!DgaXK~L!{w3I?w%$FIT_RwfUHewtcJq z7jiS+`y9=i?iD|FUl{UB$UvIq+tT1ycC^oFAGikUuSEAFpf>vd5x+kxzRms4Ki%^B z=KT5MU+sSV9a62YU;NvyFO97X8rwRoz!Jz8&DK|nGR4&xC`O!Uu6QKycqupiZJig+ z4?3f*AHJ@Vx<96NDx=zcBAXh6e9*;Bvvv0NKjZhy6aTU{kZDU=eQQWUT)P~x1>+T%(;`Hy6?8|!pdWuwQ?aF_xe>Goa z*n|y`kFUZa>}4Hi(Q~4?a_g8kBe`krm8Wq9YxP-`G+D>@mr>QTp?#xKTzg9A^n<9a z6xZ61M*RDwd-41G4mJ+XLzF8?cI*v?)dL^ zE-8LWumL(V<;NwMhcfK8-(=BoqP{na$6#OaV$M|wy$7n#mxAnRjv7Vz8n@Dleh}TK zxDLl*5A%~IzYlsLj(w+u&A8%l6I=h!IPCxElZiu}0!Q$GI|~$#!~VO*j~S20lD0;$8Ei`~pA~0`VE?QCM*TrVozv9O^UhA+GM{8UG`2NOJ$-MHU+i9M zzoUIKLib1Ls*mcc)eXM&d_O|}E%fysSaqy^{WpZyBkHs=HvOF${H^Kg@V9V;{wefz z76*UFs87Z`%35?|m5(&93!uJh-h)25_iy&S+}Th)=Z_!KhUx`*Me>bhW;NXPw}0mX{;+-#}l_*TGrj>6vebjyzjFGj-x0}~|yTg>71NlV# zDMhqLk=L$LPf|9CYHyPI+E0V_yk{SBjxr1Nv!cC4bCTtE0yIbJgZ@{^YPNeZE1mH8 W)9JnUN9SAzy_YB5JnIbM*8C4P&>=qn literal 0 HcmV?d00001 diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/run_minifier.sh b/cavis-native/cavis-native-lib/src/test/tests_cpu/run_minifier.sh new file mode 100644 index 000000000..42c68dd76 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/run_minifier.sh @@ -0,0 +1,153 @@ +#!/bin/bash +# +# /* ****************************************************************************** +# * +# * +# * This program and the accompanying materials are made available under the +# * terms of the Apache License, Version 2.0 which is available at +# * https://www.apache.org/licenses/LICENSE-2.0. +# * +# * See the NOTICE file distributed with this work for additional +# * information regarding copyright ownership. +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the +# * License for the specific language governing permissions and limitations +# * under the License. +# * +# * SPDX-License-Identifier: Apache-2.0 +# ******************************************************************************/ +# + + + + + + + +CXX=/usr/bin/g++ +#CXX_PATH=`$CXX --print-search-dirs | awk '/install/{print $2;}'` +#export CXX_PATH +export CXX + +../buildnativeoperations.sh -m -b debug +../blasbuild/cpu/blas/minifier -l -o nd4j_minilib.h +# +#echo "TESTING MINIFIER with all resources" +#echo "Testing adam_sum.fb" +#layers_tests/minifier -l -o nd4j_adam.h ./resources/adam_sum.fb +#echo "Done" +# +#echo "Testing ae_00.fb" +#layers_tests/minifier -l -o nd4j_ae.h ./resources/ae_00.fb +#echo "Done" +# +#layers_tests/minifier -l -o nd4j_conv.h ./resources/conv_0.fb +#layers_tests/minifier -l -o nd4j_expand_dim.h ./resources/expand_dim.fb +#layers_tests/minifier -l -o nd4j_inception.h ./resources/inception.fb +#layers_tests/minifier -l -o nd4j_nested_while.h ./resources/nested_while.fb +#layers_tests/minifier -l -o nd4j_partition_stitch_misc.h ./resources/partition_stitch_misc.fb +#layers_tests/minifier -l -o nd4j_reduce_dim_false.h ./resources/reduce_dim_false.fb +#layers_tests/minifier -l -o nd4j_reduce_dim.h ./resources/reduce_dim.fb +#layers_tests/minifier -l -o nd4j_reduce_dim_true.h ./resources/reduce_dim_true.fb +##layers_tests/minifier -l -o nd4j_simpleif01.h ./resources/simpleif_0_1.fb +##layers_tests/minifier -l -o nd4j_simpleif0.h ./resources/simpleif_0.fb +##layers_tests/minifier -l -o nd4j_simpleif_java.h ./resources/simpleif_0_java.fb +#layers_tests/minifier -l -o nd4j_simplewhile03.h ./resources/simplewhile_0_3.fb +#layers_tests/minifier -l -o nd4j_simplewhile04.h ./resources/simplewhile_0_4.fb +#layers_tests/minifier -l -o nd4j_simplewhile0.h ./resources/simplewhile_0.fb +#layers_tests/minifier -l -o nd4j_simplewhile1.h ./resources/simplewhile_1.fb +#layers_tests/minifier -l -o nd4j_simple_while.h ./resources/simple_while.fb 
+#layers_tests/minifier -l -o nd4j_simplewhile_nested.h ./resources/simplewhile_nested.fb +#layers_tests/minifier -l -o nd4j_tensor_array.h ./resources/tensor_array.fb +#layers_tests/minifier -l -o nd4j_tensor_array_loop.h ./resources/tensor_array_loop.fb +#layers_tests/minifier -l -o nd4j_tensor_dot_misc.h ./resources/tensor_dot_misc.fb +#layers_tests/minifier -l -o nd4j_tensor_slice.h ./resources/tensor_slice.fb +#layers_tests/minifier -l -o nd4j_three_args_while.h ./resources/three_args_while.fb +#layers_tests/minifier -l -o nd4j_transpose.h ./resources/transpose.fb +# +#echo "All Done (for g++)!!!" +# +#CXX=/usr/bin/g++-5 +#CXX_PATH=`$CXX --print-search-dirs | awk '/install/{print $2;}'` +#export CXX_PATH +#export CXX +# +#make -j4 && layers_tests/minifier -l -o nd4j_minilib.h +# +##echo "TESTING MINIFIER with all resources" +##echo "Testing adam_sum.fb" +##layers_tests/minifier -l -o nd4j_adam.h ./resources/adam_sum.fb +##echo "Done" +## +##echo "Testing ae_00.fb" +##layers_tests/minifier -l -o nd4j_ae.h ./resources/ae_00.fb +##echo "Done" +## +##layers_tests/minifier -l -o nd4j_conv.h ./resources/conv_0.fb +##layers_tests/minifier -l -o nd4j_expand_dim.h ./resources/expand_dim.fb +#layers_tests/minifier -l -o nd4j_inception.h ./resources/inception.fb +#layers_tests/minifier -l -o nd4j_nested_while.h ./resources/nested_while.fb +##layers_tests/minifier -l -o nd4j_partition_stitch_misc.h ./resources/partition_stitch_misc.fb +#layers_tests/minifier -l -o nd4j_reduce_dim_false.h ./resources/reduce_dim_false.fb +#layers_tests/minifier -l -o nd4j_reduce_dim.h ./resources/reduce_dim.fb +#layers_tests/minifier -l -o nd4j_reduce_dim_true.h ./resources/reduce_dim_true.fb +##layers_tests/minifier -l -o nd4j_simpleif01.h ./resources/simpleif_0_1.fb +##layers_tests/minifier -l -o nd4j_simpleif0.h ./resources/simpleif_0.fb +##layers_tests/minifier -l -o nd4j_simpleif_java.h ./resources/simpleif_0_java.fb +#layers_tests/minifier -l -o nd4j_simplewhile03.h 
./resources/simplewhile_0_3.fb +#layers_tests/minifier -l -o nd4j_simplewhile04.h ./resources/simplewhile_0_4.fb +#layers_tests/minifier -l -o nd4j_simplewhile0.h ./resources/simplewhile_0.fb +##layers_tests/minifier -l -o nd4j_simplewhile1.h ./resources/simplewhile_1.fb +##layers_tests/minifier -l -o nd4j_simple_while.h ./resources/simple_while.fb +##layers_tests/minifier -l -o nd4j_simplewhile_nested.h ./resources/simplewhile_nested.fb +##layers_tests/minifier -l -o nd4j_tensor_array.h ./resources/tensor_array.fb +#layers_tests/minifier -l -o nd4j_tensor_array_loop.h ./resources/tensor_array_loop.fb +#layers_tests/minifier -l -o nd4j_tensor_dot_misc.h ./resources/tensor_dot_misc.fb +#layers_tests/minifier -l -o nd4j_tensor_slice.h ./resources/tensor_slice.fb +#layers_tests/minifier -l -o nd4j_three_args_while.h ./resources/three_args_while.fb +#layers_tests/minifier -l -o nd4j_transpose.h ./resources/transpose.fb +# +#echo "All Done!!!" +# +#CXX=/usr/bin/g++-7 +#CXX_PATH=`$CXX --print-search-dirs | awk '/install/{print $2;}'` +#export CXX_PATH +#export CXX +# +#make -j4 && layers_tests/minifier -l -o nd4j_minilib.h +# +#echo "TESTING MINIFIER with all resources" +#echo "Testing adam_sum.fb" +#layers_tests/minifier -l -o nd4j_adam.h ./resources/adam_sum.fb +#echo "Done" +# +#echo "Testing ae_00.fb" +#layers_tests/minifier -l -o nd4j_ae.h ./resources/ae_00.fb +#echo "Done" +# +##layers_tests/minifier -l -o nd4j_conv.h ./resources/conv_0.fb +#layers_tests/minifier -l -o nd4j_expand_dim.h ./resources/expand_dim.fb +#layers_tests/minifier -l -o nd4j_inception.h ./resources/inception.fb +#layers_tests/minifier -l -o nd4j_nested_while.h ./resources/nested_while.fb +#layers_tests/minifier -l -o nd4j_partition_stitch_misc.h ./resources/partition_stitch_misc.fb +#layers_tests/minifier -l -o nd4j_reduce_dim_false.h ./resources/reduce_dim_false.fb +#layers_tests/minifier -l -o nd4j_reduce_dim.h ./resources/reduce_dim.fb +#layers_tests/minifier -l -o nd4j_reduce_dim_true.h 
./resources/reduce_dim_true.fb +#layers_tests/minifier -l -o nd4j_simpleif01.h ./resources/simpleif_0_1.fb +#layers_tests/minifier -l -o nd4j_simpleif0.h ./resources/simpleif_0.fb +#layers_tests/minifier -l -o nd4j_simpleif_java.h ./resources/simpleif_0_java.fb +#layers_tests/minifier -l -o nd4j_simplewhile03.h ./resources/simplewhile_0_3.fb +#layers_tests/minifier -l -o nd4j_simplewhile04.h ./resources/simplewhile_0_4.fb +#layers_tests/minifier -l -o nd4j_simplewhile0.h ./resources/simplewhile_0.fb +#layers_tests/minifier -l -o nd4j_simplewhile1.h ./resources/simplewhile_1.fb +#layers_tests/minifier -l -o nd4j_simple_while.h ./resources/simple_while.fb +#layers_tests/minifier -l -o nd4j_simplewhile_nested.h ./resources/simplewhile_nested.fb +#layers_tests/minifier -l -o nd4j_tensor_array.h ./resources/tensor_array.fb +#layers_tests/minifier -l -o nd4j_tensor_array_loop.h ./resources/tensor_array_loop.fb +#layers_tests/minifier -l -o nd4j_tensor_dot_misc.h ./resources/tensor_dot_misc.fb +#layers_tests/minifier -l -o nd4j_tensor_slice.h ./resources/tensor_slice.fb +#layers_tests/minifier -l -o nd4j_three_args_while.h ./resources/three_args_while.fb +#layers_tests/minifier -l -o nd4j_transpose.h ./resources/transpose.fb +# +echo "All Done!!!" diff --git a/cavis-native/cavis-native-lib/src/test/tests_cpu/run_tests.sh b/cavis-native/cavis-native-lib/src/test/tests_cpu/run_tests.sh new file mode 100644 index 000000000..b116e3558 --- /dev/null +++ b/cavis-native/cavis-native-lib/src/test/tests_cpu/run_tests.sh @@ -0,0 +1,68 @@ +#!/bin/sh + +# +# /* ****************************************************************************** +# * +# * +# * This program and the accompanying materials are made available under the +# * terms of the Apache License, Version 2.0 which is available at +# * https://www.apache.org/licenses/LICENSE-2.0. +# * +# * See the NOTICE file distributed with this work for additional +# * information regarding copyright ownership. 
+# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# * License for the specific language governing permissions and limitations +# * under the License. +# * +# * SPDX-License-Identifier: Apache-2.0 +# ******************************************************************************/ +# + +set -exo pipefail + +while [[ $# -gt 0 ]] +do + key="$1" + value="${2:-}" + + case $key in + -c|--chip) + CHIP="${value}" + shift # past argument + ;; + -e|--chip-extension) + CHIP_EXTENSION="$value" + shift # past argument + ;; + *) + # unknown option + ;; + esac + + if [[ $# -gt 0 ]]; then + shift # past argument or value + fi +done + +CHIP="${CHIP:-cpu}" +export GTEST_OUTPUT="xml:surefire-reports/TEST-${CHIP}-results.xml" + +# On Mac, make sure it can find libraries for GCC +export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/ + +# For Windows, add DLLs of MKL-DNN and OpenBLAS to the PATH +if [ -n "$BUILD_PATH" ]; then + if which cygpath; then + BUILD_PATH=$(cygpath -p $BUILD_PATH) + fi + export PATH="$PATH:$BUILD_PATH" +fi + +unameOut="$(uname)" +echo "$OSTYPE" + +../blasbuild/${CHIP}/${CHIP_EXTENSION}/tests_cpu/layers_tests/runtests +# Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion) +[ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/