Softmax operation implementation for mkldnn (#286)
* libnd4j first step of softmax mkldnn implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j raw implementation of mkldnn softmax Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j merge master and added softmax to MklDnnTests Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j some corrections for softmax mkldnn Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j merge branch, fixed problem with negative axis, fixed dnnl::memory::format_tag selection, test cases added Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j minor corrections to avoid risk connected with negative axis usage Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed windows builds, added switcher to use mkldnn sofmax version only for 3D, 4D, 5D, 6D arrays Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fixed dataType selection per request Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j fix for mac and windows builds Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j builds fix Signed-off-by: Oleg <oleg.semeniv@gmail.com>
This commit is contained in:
		
							parent
							
								
									1c89512ec0
								
							
						
					
					
						commit
						4d81af9fe9
					
				| @ -22,6 +22,7 @@ | ||||
| #ifndef DEV_TESTS_MKLDNNUTILS_H | ||||
| #define DEV_TESTS_MKLDNNUTILS_H | ||||
| 
 | ||||
| 
 | ||||
| #include <legacy/NativeOps.h> | ||||
| #include <array/NDArray.h> | ||||
| #include <dnnl.hpp> | ||||
| @ -86,6 +87,9 @@ namespace sd{ | ||||
|             DECLARE_PLATFORM(depthwise_conv2d_bp, ENGINE_CPU); | ||||
| 
 | ||||
|             DECLARE_PLATFORM(matmul, ENGINE_CPU); | ||||
| 
 | ||||
|             DECLARE_PLATFORM(softmax, ENGINE_CPU); | ||||
| 
 | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|  | ||||
							
								
								
									
										183
									
								
								libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										183
									
								
								libnd4j/include/ops/declarable/platform/mkldnn/softmax.cpp
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,183 @@ | ||||
| /*******************************************************************************
 | ||||
|  * Copyright (c) 2019-2020 Konduit K.K. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
|  //
 | ||||
|  //  @author Oleg Semeniv <oleg.semeniv@gmail.com>
 | ||||
|  //
 | ||||
|  //
 | ||||
| 
 | ||||
| #include <ops/declarable/PlatformHelper.h> | ||||
| #include <ops/declarable/OpRegistrator.h> | ||||
| #include <system/platform_boilerplate.h> | ||||
| #include <helpers/MKLDNNStream.h> | ||||
| #include "mkldnnUtils.h" | ||||
| 
 | ||||
| using namespace dnnl; | ||||
| 
 | ||||
| namespace sd { | ||||
|     namespace ops { | ||||
|         namespace platforms { | ||||
| 
 | ||||
|             //////////////////////////////////////////////////////////////////////
 | ||||
|             static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) { | ||||
| 
 | ||||
|                 const auto xRank = x->rankOf(); | ||||
|                 const auto zRank = z->rankOf(); | ||||
| 
 | ||||
|                 std::vector<int64_t> dimsX(xRank), dimsZ(zRank); | ||||
|                 for (auto i = 0; i < xRank; i++) { | ||||
|                     dimsX[i] = x->sizeAt(i); | ||||
|                     dimsZ[i] = z->sizeAt(i); | ||||
|                 } | ||||
| 
 | ||||
|                 dnnl::memory::dims xShape = dnnl::memory::dims(dimsX); | ||||
|                 dnnl::memory::dims zShape = dnnl::memory::dims(dimsZ); | ||||
| 
 | ||||
|                 dnnl::memory::format_tag format = dnnl::memory::format_tag::a; // 1 == xRank
 | ||||
|                 if (2 == xRank && 1 == axis) { | ||||
|                     format = dnnl::memory::format_tag::ab; | ||||
|                 } | ||||
|                 else if (2 == xRank && 0 == axis) { | ||||
|                     format = dnnl::memory::format_tag::ba; | ||||
|                 } | ||||
|                 else if (3 == xRank) { | ||||
|                     format = dnnl::memory::format_tag::abc; | ||||
|                 } | ||||
|                 else if (4 == xRank && 3 == axis) { | ||||
|                     format = dnnl::memory::format_tag::abcd; | ||||
|                 } | ||||
|                 else if (4 == xRank && 1 == axis && dimsX[2] * dimsX[3] > 1) { | ||||
|                     format = dnnl::memory::format_tag::acdb; | ||||
|                 } | ||||
|                 else if (4 == xRank) { | ||||
|                     format = dnnl::memory::format_tag::abcd; | ||||
|                 } | ||||
|                 else if (5 == xRank) { | ||||
|                     format = dnnl::memory::format_tag::abcde; | ||||
|                 } | ||||
|                 else if (6 == xRank) { | ||||
|                     format = dnnl::memory::format_tag::abcdef; | ||||
|                 } | ||||
| 
 | ||||
|                 dnnl::memory::data_type xType = dnnl::memory::data_type::f32; | ||||
|                 dnnl::memory::data_type zType = dnnl::memory::data_type::f32; | ||||
| 
 | ||||
|                 dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format); | ||||
|                 dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format); | ||||
| 
 | ||||
|                 if (x->ews() != 1 || x->ordering() != 'c') { | ||||
|                     x_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||
|                     for (auto i = 0; i < xRank; ++i) { | ||||
|                         x_user_md.data.format_desc.blocking.strides[i] = x->strideAt(i); | ||||
|                     } | ||||
|                 } | ||||
| 
 | ||||
|                 // z
 | ||||
|                 dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any); | ||||
|                 dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format); | ||||
|                 if (z->ews() != 1 || z->ordering() != 'c') { | ||||
|                     z_user_md.data.format_kind = dnnl_blocked;    // overrides format
 | ||||
|                     for (auto i = 0; i < xRank; ++i) { | ||||
|                         z_user_md.data.format_desc.blocking.strides[i] = z->strideAt(i); | ||||
|                     } | ||||
|                 } | ||||
| 
 | ||||
|                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | ||||
| 
 | ||||
|                 // Create attributes (to handle alpha and beta if necessary)
 | ||||
|                 dnnl::primitive_attr attr; // it is empty since we have usual values for alpha (=1) and beta (=0)
 | ||||
| 
 | ||||
|                 // operation primitive description
 | ||||
|                 // todo check this
 | ||||
|                 dnnl::softmax_forward::desc op_desc(dnnl::prop_kind::forward_inference, x_mkl_md, axis); | ||||
| 
 | ||||
|                 dnnl::softmax_forward::primitive_desc op_prim_desc(op_desc, attr, engine); | ||||
| 
 | ||||
|                 // arguments (memory buffers) necessary for calculations
 | ||||
|                 std::unordered_map<int, dnnl::memory> args; | ||||
| 
 | ||||
|                 dnnl::stream stream(engine); | ||||
| 
 | ||||
|                 // provide memory buffers and check whether reorder is required
 | ||||
| 
 | ||||
|                 // input
 | ||||
|                 auto x_user_mem = dnnl::memory(x_user_md, engine, x->getBuffer()); | ||||
|                 const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc(); | ||||
|                 auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem; | ||||
|                 if (xReorder) | ||||
|                     dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem); | ||||
|                 args[DNNL_ARG_SRC] = x_mkl_mem; | ||||
| 
 | ||||
|                 // z
 | ||||
|                 auto z_user_mem = dnnl::memory(z_user_md, engine, z->getBuffer()); | ||||
|                 const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc(); | ||||
|                 auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem; | ||||
|                 args[DNNL_ARG_DST] = z_mkl_mem; | ||||
| 
 | ||||
|                 // run calculations
 | ||||
|                 dnnl::softmax_forward(op_prim_desc).execute(stream, args); | ||||
| 
 | ||||
|                 // reorder outputs if necessary
 | ||||
|                 if (zReorder) | ||||
|                     dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem); | ||||
| 
 | ||||
|                 stream.wait(); | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
|             PLATFORM_IMPL(softmax, ENGINE_CPU) { | ||||
| 
 | ||||
|                 auto input = INPUT_VARIABLE(0); | ||||
|                 auto output = OUTPUT_VARIABLE(0); | ||||
| 
 | ||||
|                 const int rank = input->rankOf(); | ||||
|                 int dim = block.getIArguments()->size() > 0 ? INT_ARG(0) : rank - 1; | ||||
| 
 | ||||
|                 if (dim < 0) { | ||||
|                     dim += rank; | ||||
|                 } | ||||
| 
 | ||||
|                 REQUIRE_TRUE(dim < rank && dim >= 0, 0, "SOFTMAX_MKLDNN OP: the value of input integer parameter (dimension) must be less than input array rank %i, but got dimension = %i instead !", rank, dim); | ||||
| 
 | ||||
|                 REQUIRE_TRUE(rank <= 6, 0, "SOFTMAX_MKLDNN OP: the rank of input must be less or qual 4, but got rank = %i instead !", rank); | ||||
| 
 | ||||
|                 // mkldnnSoftMax
 | ||||
|                 softmaxMKLDNN(input, output, dim); | ||||
| 
 | ||||
|                 return Status::OK(); | ||||
|             } | ||||
| 
 | ||||
|             PLATFORM_CHECK(softmax, ENGINE_CPU) { | ||||
| 
 | ||||
|                 auto x = INPUT_VARIABLE(0); | ||||
|                 auto z = OUTPUT_VARIABLE(0); | ||||
| 
 | ||||
|                 const DataType xType = x->dataType(); | ||||
|                 const DataType zType = z->dataType(); | ||||
| 
 | ||||
|                 const int xRank = x->rankOf(); | ||||
|                 bool bSupportedRanks = (xRank > 2 && xRank < 7); | ||||
|                 /*
 | ||||
|                 Source     Destination | ||||
|                 f32 	    f32 | ||||
|                 */ | ||||
|                 return  block.isUseMKLDNN() && bSupportedRanks && (xType == DataType::FLOAT32 && zType == DataType::FLOAT32); | ||||
| 
 | ||||
|             } | ||||
| 
 | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -2921,8 +2921,10 @@ TEST_F(DeclarableOpsTests1, Test_Range_Integer_3) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test1) { | ||||
|     auto input = NDArrayFactory::create<double>('c', {3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, 5}); | ||||
|     auto expOutput = NDArrayFactory::create<double>('c', {3, 3}, {1.14195199e-01, 8.43794734e-01, 4.20100661e-02, 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, 9.02116571e-05, 2.68917160e-01, 7.30992629e-01}); | ||||
| 
 | ||||
|     NDArray input('c', { 3, 3 }, { -1.f, 1.f, -2.f, 2.f, -3.f, 3.f, -4.f, 4.f, 5.f }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     NDArray expOutput('c', { 3, 3 }, { 1.14195199e-01, 8.43794734e-01, 4.20100661e-02, 2.68454951e-01, 1.80883523e-03, 7.29736214e-01, 9.02116571e-05, 2.68917160e-01, 7.30992629e-01 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, {}, {}); | ||||
| @ -2937,8 +2939,8 @@ TEST_F(DeclarableOpsTests1, softmax_test1) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test2) { | ||||
|     auto input = NDArrayFactory::create<double>('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); | ||||
|     auto expOutput = NDArrayFactory::create<double>('c', {3, 3, 3}, {4.73142e-02,4.73847e-02,6.69062e-03, 9.50330e-01,8.67881e-04,9.92976e-01, 2.35563e-03,9.51747e-01,3.33106e-04, 4.74259e-02,2.26032e-06,4.74259e-02, 2.91395e-07,9.99998e-01,3.94360e-08, 9.52574e-01,1.12535e-07,9.52574e-01, 7.58256e-10,4.74259e-02,1.22325e-11, 1.00000e+00,1.32293e-11,1.19203e-01, 3.77513e-11,9.52574e-01,8.80797e-01}); | ||||
|     NDArray input('c', { 3, 3, 3 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14 }, sd::DataType::FLOAT32); | ||||
|     NDArray expOutput('c', { 3, 3, 3 }, { 4.73142e-02,4.73847e-02,6.69062e-03, 9.50330e-01,8.67881e-04,9.92976e-01, 2.35563e-03,9.51747e-01,3.33106e-04, 4.74259e-02,2.26032e-06,4.74259e-02, 2.91395e-07,9.99998e-01,3.94360e-08, 9.52574e-01,1.12535e-07,9.52574e-01, 7.58256e-10,4.74259e-02,1.22325e-11, 1.00000e+00,1.32293e-11,1.19203e-01, 3.77513e-11,9.52574e-01,8.80797e-01 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, { 1 }, {}); | ||||
| @ -2953,8 +2955,8 @@ TEST_F(DeclarableOpsTests1, softmax_test2) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test3) { | ||||
|     auto input = NDArrayFactory::create<double>('c', {3, 3, 3}, {-1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14}); | ||||
|     auto expOutput = NDArrayFactory::create<double>('c', {3, 3, 3}, {2.47262e-03,1.23395e-04,3.35350e-04, 1.23395e-04,4.53979e-05,1.23395e-04, 6.14417e-06,1.23395e-04,5.56530e-09, 9.97527e-01,1.12521e-07,9.99665e-01, 1.52281e-08,9.99955e-01,2.06090e-09, 9.99994e-01,2.78912e-10,6.69285e-03, 3.05146e-07,9.99876e-01,4.13855e-08, 9.99877e-01,5.60254e-09,9.99877e-01, 7.58251e-10,9.99877e-01,9.93307e-01}); | ||||
|     NDArray input('c', { 3, 3, 3 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14 }, sd::DataType::FLOAT32); | ||||
|     NDArray expOutput('c', { 3, 3, 3 }, { 2.47262e-03,1.23395e-04,3.35350e-04, 1.23395e-04,4.53979e-05,1.23395e-04, 6.14417e-06,1.23395e-04,5.56530e-09, 9.97527e-01,1.12521e-07,9.99665e-01, 1.52281e-08,9.99955e-01,2.06090e-09, 9.99994e-01,2.78912e-10,6.69285e-03, 3.05146e-07,9.99876e-01,4.13855e-08, 9.99877e-01,5.60254e-09,9.99877e-01, 7.58251e-10,9.99877e-01,9.93307e-01 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, { 0 }, {}); | ||||
| @ -2969,8 +2971,8 @@ TEST_F(DeclarableOpsTests1, softmax_test3) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test4) { | ||||
|     auto input = NDArrayFactory::create<double>('c', {1, 5}, {-1, 1, -2, 2, 3}); | ||||
|     auto expOutput = NDArrayFactory::create<double>('c', {1, 5}, {0.01198,0.08855,0.00441,0.24072,0.65434}); | ||||
|     NDArray input('c', { 1, 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); | ||||
|     NDArray expOutput('c', { 1, 5 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, { 1 }, {}); | ||||
| @ -2985,8 +2987,8 @@ TEST_F(DeclarableOpsTests1, softmax_test4) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test5) { | ||||
|     auto input = NDArrayFactory::create<double>('c', {1, 5}, {-1, 1, -2, 2, 3}); | ||||
|     auto expOutput = NDArrayFactory::create<double>('c', {1, 5}, {1,1,1,1,1}); | ||||
|     NDArray input('c', { 1, 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); | ||||
|     NDArray expOutput('c', { 1, 5 }, { 1,1,1,1,1 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, { 0 }); | ||||
| @ -3001,8 +3003,8 @@ TEST_F(DeclarableOpsTests1, softmax_test5) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test6) { | ||||
|     auto input = NDArrayFactory::create<double>('c', {5, 1}, {-1, 1, -2, 2, 3}); | ||||
|     auto expOutput = NDArrayFactory::create<double>('c', {5, 1}, {0.01198,0.08855,0.00441,0.24072,0.65434}); | ||||
|     NDArray input('c', { 5, 1 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); | ||||
|     NDArray expOutput('c', { 5, 1 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, { 0 }, {}); | ||||
| @ -3017,8 +3019,8 @@ TEST_F(DeclarableOpsTests1, softmax_test6) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test7) { | ||||
|     auto input = NDArrayFactory::create<double>('c', {5, 1}, {-1, 1, -2, 2, 3}); | ||||
|     auto expOutput = NDArrayFactory::create<double>('c', {5, 1}, {1,1,1,1,1}); | ||||
|     NDArray input('c', { 5, 1 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); | ||||
|     NDArray expOutput('c', { 5, 1 }, { 1,1,1,1,1 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, { 1 }, {}); | ||||
| @ -3033,8 +3035,8 @@ TEST_F(DeclarableOpsTests1, softmax_test7) { | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test8) { | ||||
|     auto input = NDArrayFactory::create<double>('c', {5}, {-1, 1, -2, 2, 3}); | ||||
|     auto expOutput = NDArrayFactory::create<double>('c', {5}, {0.01198,0.08855,0.00441,0.24072,0.65434}); | ||||
|     NDArray input('c', { 5 }, { -1, 1, -2, 2, 3 }, sd::DataType::FLOAT32); | ||||
|     NDArray expOutput('c', { 5 }, { 0.01198,0.08855,0.00441,0.24072,0.65434 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, {}, {}); | ||||
| @ -3047,6 +3049,70 @@ TEST_F(DeclarableOpsTests1, softmax_test8) { | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test9) { | ||||
|     NDArray input('c', { 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8 }, sd::DataType::FLOAT32); | ||||
|     NDArray expOutput('c', { 2, 2, 2, 2 }, { 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, { 2 }, {}); | ||||
|     auto z = results->at(0); | ||||
| 
 | ||||
|     ASSERT_EQ(Status::OK(), results->status()); | ||||
|     ASSERT_TRUE(expOutput.isSameShape(z)); | ||||
|     ASSERT_TRUE(expOutput.equalsTo(z)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test10) { | ||||
|     NDArray input('c', { 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16 }, sd::DataType::FLOAT32); | ||||
|     NDArray expOutput('c', { 2, 2, 2, 2, 2 }, { 0.119203, 0.880797, 0.017986, 0.982014, 0.002473, 0.997527, 0.000335, 0.999665, 0.000045, 0.999955, 0.000006, 0.999994, 0.000001, 0.999999, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 1.000000, 0.000000, 1.000000, 0.000000, 1.000000, 0.00000 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, { 4 }, {}); | ||||
|     auto z = results->at(0); | ||||
| 
 | ||||
|     ASSERT_EQ(Status::OK(), results->status()); | ||||
|     ASSERT_TRUE(expOutput.isSameShape(z)); | ||||
|     ASSERT_TRUE(expOutput.equalsTo(z)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test11) { | ||||
|     NDArray input('c', { 2, 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16, -2.1, 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5,2.5 ,-2.6,2.6, -2.7,2.7, -2.8,2.8, -2.9,2.9, -3.0,3.0, -3.1,3.1, -3.2,3.2, -3.3,3.3, 3.4, -3.4, 3.5, -3.5, 3.6,-3.6 }, sd::DataType::FLOAT32); | ||||
|     NDArray expOutput('c', { 2, 2, 2, 2, 2, 2 }, { 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.731059, 0.268941, 0.268941, 0.731059, 0.000000, 1.000000, 1.000000, 0.000000, 0.268941, 0.731059, 0.731059, 0.268941, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.524979, 0.475021, 0.475021, 0.524979, 0.001229, 0.998771, 0.998771, 0.001229, 0.475021, 0.524979, 0.524979, 0.475021 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, { 4 }, {}); | ||||
|     auto z = results->at(0); | ||||
| 
 | ||||
|     ASSERT_EQ(Status::OK(), results->status()); | ||||
|     ASSERT_TRUE(expOutput.isSameShape(z)); | ||||
|     ASSERT_TRUE(expOutput.equalsTo(z)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, softmax_test12) { | ||||
|     NDArray input('f', { 2, 2, 2, 2, 2, 2 }, { -1, 1, -2, 2, -3, 3, -4, 4, -5,5 ,-6,6, -7,7, -8,8, -9,9, -10,10, -11,11, -12,12, -13,13, 14, -14, 15, -15, 16,-16, -2.1, 2.1, -2.2, 2.2, -2.3, 2.3, -2.4, 2.4, -2.5,2.5 ,-2.6,2.6, -2.7,2.7, -2.8,2.8, -2.9,2.9, -3.0,3.0, -3.1,3.1, -3.2,3.2, -3.3,3.3, 3.4, -3.4, 3.5, -3.5, 3.6,-3.6 }, sd::DataType::FLOAT32); | ||||
|     NDArray exp('c', { 2, 2, 2, 2, 2, 2 }, { 0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, 0.017986, 0.401312, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, 0.401312, 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001659, 0.017986, 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001113, 0.017986, 0.401312, 1.000000, 0.998887, 0.017986, 0.401312, 0.017986, 0.401312, 0.982014, 0.598688, 0.982014, 0.598688, 0.017986, 0.401312, 1.000000, 0.998641, 0.982014, 0.598688, 0.000000, 0.001359, 0.017986, 0.401312, 1.000000, 0.998341, 0.982014, 0.598688, 0.000000, 0.001659, 0.017986, 0.401312, 1.000000, 0.998887, 0.982014, 0.598688, 0.000000, 0.001113 }, sd::DataType::FLOAT32); | ||||
| 
 | ||||
|     auto expOutput = NDArray('f', { 2, 2, 2, 2, 2, 2 }, sd::DataType::FLOAT32); | ||||
|     expOutput.assign(exp); | ||||
| 
 | ||||
|     sd::ops::softmax op; | ||||
|     auto results = op.evaluate({ &input }, {}, { 3 }, {}); | ||||
|     auto z = results->at(0); | ||||
| 
 | ||||
|     ASSERT_EQ(Status::OK(), results->status()); | ||||
|     ASSERT_TRUE(expOutput.isSameShape(z)); | ||||
|     ASSERT_TRUE(expOutput.equalsTo(z)); | ||||
| 
 | ||||
|     delete results; | ||||
| } | ||||
| //////////////////////////////////////////////////////////////////////
 | ||||
| TEST_F(DeclarableOpsTests1, Reverse_1) { | ||||
| 
 | ||||
| @ -3436,4 +3502,3 @@ TEST_F(DeclarableOpsTests1, Test_Release) { | ||||
|     auto x = NDArrayFactory::create<float>('c', { 8, 8 }); | ||||
|     // x.printShapeInfo("x shape");
 | ||||
| } | ||||
| 
 | ||||
|  | ||||
| @ -69,6 +69,8 @@ TEST_F(MklDnnTests, helpers_includer) { | ||||
| 
 | ||||
|     sd::ops::platforms::PLATFORM_matmul_ENGINE_CPU matmul; | ||||
| 
 | ||||
|     printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul}); | ||||
|     sd::ops::platforms::PLATFORM_softmax_ENGINE_CPU softmax; | ||||
| 
 | ||||
|     printer({&conv2d, &conv2d_bp, &conv3d, &conv3d_bp, &avgpool2d, &avgpool2d_bp, &maxpool2d, &maxpool2d_bp, &avgpool3d, &avgpool3d_bp, &maxpool3d, &maxpool3d_bp, &lrn, &batchnorm, &matmul, &softmax }); | ||||
| #endif | ||||
| } | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user