cavis/libnd4j/tests_cpu/layers_tests/ReduceTests.cpp

155 lines
90 KiB
C++
Raw Normal View History

2019-06-06 14:21:15 +02:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// Created by agibsonccc on 1/15/17.
//
#include <memory>
#include <vector>
#include <helpers/ShapeUtils.h>
#include "testinclude.h"
#include <loops/reduce3.h>
#include <loops/reduce_float.h>
#include <ArrayOptions.h>
class ReduceTest : public testing::Test {
public:
Nd4jLong shape[2] = {500,3};
float x[1500] = {4.0,2.0,3.0,8.0,4.0,6.0,12.0,6.0,9.0,16.0,8.0,12.0,20.0,10.0,15.0,24.0,12.0,18.0,28.0,14.0,21.0,32.0,16.0,24.0,36.0,18.0,27.0,40.0,20.0,30.0,44.0,22.0,33.0,48.0,24.0,36.0,52.0,26.0,39.0,56.0,28.0,42.0,60.0,30.0,45.0,64.0,32.0,48.0,68.0,34.0,51.0,72.0,36.0,54.0,76.0,38.0,57.0,80.0,40.0,60.0,84.0,42.0,63.0,88.0,44.0,66.0,92.0,46.0,69.0,96.0,48.0,72.0,100.0,50.0,75.0,104.0,52.0,78.0,108.0,54.0,81.0,112.0,56.0,84.0,116.0,58.0,87.0,120.0,60.0,90.0,124.0,62.0,93.0,128.0,64.0,96.0,132.0,66.0,99.0,136.0,68.0,102.0,140.0,70.0,105.0,144.0,72.0,108.0,148.0,74.0,111.0,152.0,76.0,114.0,156.0,78.0,117.0,160.0,80.0,120.0,164.0,82.0,123.0,168.0,84.0,126.0,172.0,86.0,129.0,176.0,88.0,132.0,180.0,90.0,135.0,184.0,92.0,138.0,188.0,94.0,141.0,192.0,96.0,144.0,196.0,98.0,147.0,200.0,100.0,150.0,204.0,102.0,153.0,208.0,104.0,156.0,212.0,106.0,159.0,216.0,108.0,162.0,220.0,110.0,165.0,224.0,112.0,168.0,228.0,114.0,171.0,232.0,116.0,174.0,236.0,118.0,177.0,240.0,120.0,180.0,244.0,122.0,183.0,248.0,124.0,186.0,252.0,126.0,189.0,256.0,128.0,192.0,260.0,130.0,195.0,264.0,132.0,198.0,268.0,134.0,201.0,272.0,136.0,204.0,276.0,138.0,207.0,280.0,140.0,210.0,284.0,142.0,213.0,288.0,144.0,216.0,292.0,146.0,219.0,296.0,148.0,222.0,300.0,150.0,225.0,304.0,152.0,228.0,308.0,154.0,231.0,312.0,156.0,234.0,316.0,158.0,237.0,320.0,160.0,240.0,324.0,162.0,243.0,328.0,164.0,246.0,332.0,166.0,249.0,336.0,168.0,252.0,340.0,170.0,255.0,344.0,172.0,258.0,348.0,174.0,261.0,352.0,176.0,264.0,356.0,178.0,267.0,360.0,180.0,270.0,364.0,182.0,273.0,368.0,184.0,276.0,372.0,186.0,279.0,376.0,188.0,282.0,380.0,190.0,285.0,384.0,192.0,288.0,388.0,194.0,291.0,392.0,196.0,294.0,396.0,198.0,297.0,400.0,200.0,300.0,404.0,202.0,303.0,408.0,204.0,306.0,412.0,206.0,309.0,416.0,208.0,312.0,420.0,210.0,315.0,424.0,212.0,318.0,428.0,214.0,321.0,432.0,216.0,324.0,436.0,218.0,327.0,440.0,220.0,330.0,444.0,222.0,333.0,448.0,224.0,336.0,452.0,226.0,339.0,456.0,228.0,342.0,460.0,230.0,345.0,464.0,232.0,348.0,468.0,234.
0,351.0,472.0,236.0,354.0,476.0,238.0,357.0,480.0,240.0,360.0,484.0,242.0,363.0,488.0,244.0,366.0,492.0,246.0,369.0,496.0,248.0,372.0,500.0,250.0,375.0,504.0,252.0,378.0,508.0,254.0,381.0,512.0,256.0,384.0,516.0,258.0,387.0,520.0,260.0,390.0,524.0,262.0,393.0,528.0,264.0,396.0,532.0,266.0,399.0,536.0,268.0,402.0,540.0,270.0,405.0,544.0,272.0,408.0,548.0,274.0,411.0,552.0,276.0,414.0,556.0,278.0,417.0,560.0,280.0,420.0,564.0,282.0,423.0,568.0,284.0,426.0,572.0,286.0,429.0,576.0,288.0,432.0,580.0,290.0,435.0,584.0,292.0,438.0,588.0,294.0,441.0,592.0,296.0,444.0,596.0,298.0,447.0,600.0,300.0,450.0,604.0,302.0,453.0,608.0,304.0,456.0,612.0,306.0,459.0,616.0,308.0,462.0,620.0,310.0,465.0,624.0,312.0,468.0,628.0,314.0,471.0,632.0,316.0,474.0,636.0,318.0,477.0,640.0,320.0,480.0,644.0,322.0,483.0,648.0,324.0,486.0,652.0,326.0,489.0,656.0,328.0,492.0,660.0,330.0,495.0,664.0,332.0,498.0,668.0,334.0,501.0,672.0,336.0,504.0,676.0,338.0,507.0,680.0,340.0,510.0,684.0,342.0,513.0,688.0,344.0,516.0,692.0,346.0,519.0,696.0,348.0,522.0,700.0,350.0,525.0,704.0,352.0,528.0,708.0,354.0,531.0,712.0,356.0,534.0,716.0,358.0,537.0,720.0,360.0,540.0,724.0,362.0,543.0,728.0,364.0,546.0,732.0,366.0,549.0,736.0,368.0,552.0,740.0,370.0,555.0,744.0,372.0,558.0,748.0,374.0,561.0,752.0,376.0,564.0,756.0,378.0,567.0,760.0,380.0,570.0,764.0,382.0,573.0,768.0,384.0,576.0,772.0,386.0,579.0,776.0,388.0,582.0,780.0,390.0,585.0,784.0,392.0,588.0,788.0,394.0,591.0,792.0,396.0,594.0,796.0,398.0,597.0,800.0,400.0,600.0,804.0,402.0,603.0,808.0,404.0,606.0,812.0,406.0,609.0,816.0,408.0,612.0,820.0,410.0,615.0,824.0,412.0,618.0,828.0,414.0,621.0,832.0,416.0,624.0,836.0,418.0,627.0,840.0,420.0,630.0,844.0,422.0,633.0,848.0,424.0,636.0,852.0,426.0,639.0,856.0,428.0,642.0,860.0,430.0,645.0,864.0,432.0,648.0,868.0,434.0,651.0,872.0,436.0,654.0,876.0,438.0,657.0,880.0,440.0,660.0,884.0,442.0,663.0,888.0,444.0,666.0,892.0,446.0,669.0,896.0,448.0,672.0,900.0,450.0,675.0,904.0,452.0,678.0,908.0,454.0,681.0,912.0,456.0,
684.0,916.0,458.0,687.0,920.0,460.0,690.0,924.0,462.0,693.0,928.0,464.0,696.0,932.0,466.0,6
float result[1500] = {0};
int dimension[1] = {0};
std::vector<int> dim = {0};
int dimensionLength = 1;
float theoreticalMin[3] = {4,2,3};
float theoreticalMax[3] = {2000.00, 1000.00, 1500.00};
float theoreticalRange[3] = {1996.00, 998.00, 1497.00};
};
class StdTest : public testing::Test {
public:
Nd4jLong examplesShape[4] = {10,5,10,15};
int dimensionsForStd[3] = {0,2,3};
std::vector<int> dimsForStd = {0,2,3};
int dimensionLength = 3;
//standard deviation
int opNum = 1;
float x[7500] ={0.5786382,0.16236664,0.069020785,0.9840061,0.941816,0.76720303,0.7794372,0.46979624,0.73381734,0.9957244,0.6167372,0.53088397,0.28015637,0.826945,0.83352476,0.66504276,0.5793391,0.47484478,0.7076381,0.49456358,0.62396896,0.53332835,0.6388812,0.68836075,0.26663998,0.0014623206,0.19409843,0.56639415,0.98213744,0.68497056,0.867037,0.76840234,0.318186,0.28759065,0.11965875,0.53291357,0.53767395,0.55705845,0.7467155,0.1575149,0.18076386,0.8174763,0.22883898,0.5071535,0.86735153,0.9635827,0.24558435,0.15767147,0.458882,0.71102697,0.21914826,0.16241662,0.27248728,0.89015275,0.71070856,0.55088985,0.98992974,0.70927286,0.9261268,0.50781846,0.62151235,0.4590896,0.7487442,0.21744072,0.2636398,0.084352165,0.46951914,0.383644,0.6749645,0.24111961,0.83259743,0.05546627,0.4790621,0.68884027,0.90992177,0.23907907,0.5342047,0.221003,0.29615387,0.43343517,0.16554528,0.73144174,0.52923626,0.10688303,0.78197056,0.39259177,0.43832788,0.052234255,0.5795483,0.97033966,0.7392455,0.086584255,0.9092887,0.9402065,0.9126419,0.44749174,0.20514569,0.8749829,0.30917913,0.10170506,0.37034252,0.7427814,0.5497875,0.3116048,0.12112484,0.07918618,0.6003074,0.6188079,0.6292188,0.26580265,0.42029652,0.9863358,0.41489154,0.23757206,0.30395788,0.75231904,0.76751274,0.6324773,0.3231405,0.5016677,0.86029065,0.575702,0.7473972,0.118974194,0.115586124,0.62481487,0.91101325,0.6137756,0.71462154,0.995567,0.93439484,0.37260458,0.6033152,0.3444346,0.91579247,0.7452442,0.97466874,0.6299154,0.35426098,0.50121397,0.14155711,0.78726757,0.028531995,0.8435531,0.6444501,0.8826095,0.25354537,0.5547923,0.99555415,0.8430975,246.29712,253.4231,282.26755,215.6161,251.57019,239.20515,296.2021,234.32518,278.9852,235.4248,238.70155,256.9956,212.62695,288.38763,231.21237,284.80396,261.86835,223.92522,205.86221,234.742,262.11407,298.1942,242.60652,238.83704,251.6588,267.23315,294.4865,223.47488,259.24976,251.82695,265.01166,234.65732,265.1853,202.15352,244.42313,253.90427,212.09233,227.62961,237.77951,261.36838,23
4.32147,240.81522,273.62595,221.19333,284.11353,216.00859,284.36948,243.90376,282.61584,256.97165,275.08722,253.8055,265.1405,298.87567,223.393,288.02148,287.26102,276.36237,290.52777,299.57062,224.73566,290.82623,231.3513,238.51828,230.74028,224.97539,290.11844,238.00816,290.39606,291.32538,272.94766,211.88446,291.66742,210.34077,285.62628,246.31918,283.68738,282.34418,223.43613,245.08679,235.22693,246.01146,224.03375,280.5359,226.01413,262.18884,237.87335,238.89404,259.04294,202.59842,294.69302,209.01956,244.75763,264.3232,293.4627,287.69165,236.79088,282.37012,222.24211,293.5885,249.6388,273.91916,215.40356,255.45584,268.4702,275.81577,259.25064,224.95108,250.37906,267.89093,256.31766,227.89124,204.10915,263.38596,213.62708,218.84116,289.00494,216.93646,200.29439,284.1103,216.20671,260.57642,248.57745,241.73776,244.7205,286.86218,206.42664,204.06395,216.60626,224.02377,219.4697,287.2509,246.91132,289.83777,292.73767,202.73048,206.4165,294.0605,276.23276,288.51318,279.45175,253.69833,281.3311,249.44318,287.76288,262.2878,238.2247,203.41438,208.8359,274.0062,-9.999092,-9.99934,-9.999794,-9.999654,-9.999987,-9.999574,-9.99965,-9.999892,-9.999203,-9.999798,-9.999658,-9.999974,-9.999982,-9.999003,-9.999369,-9.999311,-9.999708,-9.999327,-9.999302,-9.999419,-9.999553,-9.9991665,-9.999842,-9.9991665,-9.999702,-9.999081,-9.9993725,-9.999735,-9.999399,-9.999073,-9.999045,-9.999458,-9.99971,-9.999414,-9.999165,-9.999782,-9.999417,-9.999513,-9.999398,-9.999933,-9.999367,-9.999933,-9.999302,-9.999572,-9.999926,-9.999371,-9.999746,-9.999628,-9.9995165,-9.999816,-9.9998255,-9.999983,-9.999482,-9.99976,-9.999302,-9.999825,-9.999026,-9.999029,-9.999147,-9.9995,-9.999214,-9.999216,-9.999818,-9.999334,-9.999354,-9.999414,-9.999564,-9.99962,-9.999615,-9.999496,-9.999803,-9.999454,-9.999789,-9.999615,-9.999473,-9.999701,-9.999164,-9.999112,-9.9991865,-9.999779,-9.999639,-9.999739,-9.999949,-9.999005,-9.999157,-9.999394,-9.999148,-9.999729,-9.999721,-9.999721,-9.999678,-9.999215,-9.99
921,-9.999848,-9.999702,-9.999167,-9.999995,-9.999203,-9.999381,-9.999537,-9.999643,-9.9998
};
/**
 * Fixture holding the operands, output buffer and reduction settings consumed
 * by the reduce3 test case built on top of it.
 */
class EuclideanDistanceTest : public testing::Test {
public:
    // Left-hand operand: the sequence 1..8 stored twice.
    float x[16]{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
                1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f};
    // Right-hand operand: every element is the matching x element plus one.
    float y[16]{2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f,
                2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
    // Output buffer, zero-filled.
    float result[9]{};
    // Shape descriptor shared by x and y.
    // NOTE(review): appears to follow the shapeInfo layout (rank 4, dims 2x2x2x2,
    // strides 8/4/2/1, then trailing fields) - confirm against shape.h.
    Nd4jLong shapeBuffer[12]{4, 2, 2, 2, 2, 8, 4, 2, 1, 0, 1, 99};
    // Number of entries in `dimension`.
    int dimensionLength{3};
    // Axes to reduce over, as a raw array for the legacy exec() interface.
    int dimension[3]{1, 2, 3};
    // Extra parameters handed to the op; both zero.
    float extraVals[2]{0.0f, 0.0f};
    // Index of the reduce3 op to run.
    // NOTE(review): presumably Euclidean distance given the fixture name - confirm
    // against the reduce3 op table.
    int opNum{1};
    // Same axes as `dimension`, as a vector for ShapeUtils::evalReduceShapeInfo.
    std::vector<int> dim{1, 2, 3};
};
/**
 * Runs the reduce3 op selected by the fixture's opNum over x and y along the
 * fixture's dimensions, then checks that the first two output values are equal.
 */
TEST_F(EuclideanDistanceTest,Test1) {
    // Write the FLOAT32 data type into the shared shape descriptor before it is used.
    nd4j::ArrayOptions::setDataType(shapeBuffer, nd4j::DataType::FLOAT32);

    // Shape descriptor of the reduced output.
    auto reducedShapeInfo = nd4j::ShapeUtils::evalReduceShapeInfo('c', dim, shapeBuffer, false, true, nullptr);

    // x and y share one shape descriptor; the output goes into the fixture's result buffer.
    functions::reduce3::Reduce3<float, float>::exec(
        opNum,
        x, shapeBuffer,
        extraVals,
        y, shapeBuffer,
        result, reducedShapeInfo,
        dimension, dimensionLength);

    ASSERT_EQ(result[1],result[0]);
}
/**
 * Reduces the fixture's x buffer (shape built from examplesShape) along
 * dimensionsForStd using the float-reduce op selected by the fixture's opNum.
 *
 * Only the reduced length and the TAD element-wise stride are asserted; the
 * reduced values are computed but not checked.
 *
 * Every resource is held by an RAII owner: ASSERT_* returns from the test body
 * on failure, which would skip manual delete / delete[] calls placed at the end.
 */
TEST_F(StdTest,MultiDimTest) {
    // The descriptor is released with delete[] (unique_ptr<T[]>), matching the
    // explicit delete[] this test previously performed.
    std::unique_ptr<Nd4jLong[]> xShapeInfo(shape::shapeBuffer(4, nd4j::DataType::FLOAT32, examplesShape));

    // NOTE(review): resultShapeInfo is not released here, as before - confirm
    // whether evalReduceShapeInfo hands back caller-owned memory.
    auto resultShapeInfo = nd4j::ShapeUtils::evalReduceShapeInfo('c', dimsForStd, xShapeInfo.get(), false, true, nullptr);
    int resultLengthAssertion = 5;
    ASSERT_EQ(resultLengthAssertion,shape::length(resultShapeInfo));

    // Stack-allocated so it is destroyed on every exit path.
    shape::TAD tad;
    tad.init(xShapeInfo.get(),dimensionsForStd,dimensionLength);
    tad.createTadOnlyShapeInfo();
    tad.createOffsets();
    int tadElementWiseStride = shape::elementWiseStride(tad.tadOnlyShapeInfo);
    ASSERT_EQ(0,tadElementWiseStride);

    // Extra-parameter buffer for the op: a single zero.
    float none[1] = {0};
    // Output buffer sized to the reduced length (zero-initialised by vector).
    std::vector<float> result(shape::length(resultShapeInfo));
    functions::reduce::ReduceFloatFunction<float,float>::exec(
        opNum,
        x,
        xShapeInfo.get(),
        none,
        result.data(),
        resultShapeInfo,
        dimensionsForStd,
        dimensionLength,
        tad.tadOnlyShapeInfo,
        tad.tadOffsets);
    // Locals are destroyed in reverse order (result, tad, xShapeInfo), the same
    // order as the manual cleanup they replace.
}
/**
 * Reduces the fixture's x buffer (2-D shape taken from `shape`, 'c' order)
 * along `dimension` with float-reduce op number 4, writing into the fixture's
 * result buffer.
 *
 * Only the reduced length and the TAD element-wise stride are asserted; the
 * reduced values are computed but not checked.
 *
 * Every resource is held by an RAII owner: ASSERT_* returns from the test body
 * on failure, which would skip manual delete / delete[] calls placed at the end.
 */
TEST_F(ReduceTest,MatrixTest) {
    // NOTE(review): which reduction op index 4 selects cannot be told from this file.
    int opNum = 4;

    // The descriptor is released with delete[] (unique_ptr<T[]>), matching the
    // explicit delete[] this test previously performed.
    std::unique_ptr<Nd4jLong[]> xShapeInfo(nd4j::ShapeBuilders::createShapeInfo(nd4j::DataType::FLOAT32, 'c', 2, shape));

    // NOTE(review): resultShapeInfo is not released here, as before - confirm
    // whether evalReduceShapeInfo hands back caller-owned memory.
    auto resultShapeInfo = nd4j::ShapeUtils::evalReduceShapeInfo('c', dim, xShapeInfo.get(), false, true, nullptr);
    int resultLengthAssertion = 3;
    ASSERT_EQ(resultLengthAssertion,shape::length(resultShapeInfo));

    // Stack-allocated so it is destroyed on every exit path.
    shape::TAD tad;
    tad.init(xShapeInfo.get(),dimension,dimensionLength);
    tad.createTadOnlyShapeInfo();
    tad.createOffsets();
    auto tadElementWiseStride = shape::elementWiseStride(tad.tadOnlyShapeInfo);
    ASSERT_EQ(3,tadElementWiseStride);

    // Extra-parameter buffer for the op: a single zero.
    float none[1] = {0};
    functions::reduce::ReduceFloatFunction<float,float>::exec(
        opNum,
        x,
        xShapeInfo.get(),
        none,
        result,
        resultShapeInfo,
        dimension,
        dimensionLength,
        tad.tadOnlyShapeInfo,
        tad.tadOffsets);
    // Locals are destroyed in reverse order (tad, then xShapeInfo), the same
    // order as the manual cleanup they replace.
}