Shyrma adjust (#98)
* Add the possibility of passing a scalar array as the input parameter for the scale factor in the adjust hue/contrast/saturation ops; correct a typo in the function that calculates the regularized incomplete beta integral
* Fix a bug in the betainc CUDA kernel
* Start working on the implementation of the digamma function
* Further work on the digamma function (CPU)
* Testing and fixing bugs in the digamma op
* Make a correction in the CUDA kernel for polyGamma
* Remove unnecessary code from the betaInc CUDA kernel
* Resolve conflicts in DeclarableOpsTests3.cpp after the master branch was merged
* Restore the id number of the Not operation in legacy_ops.h
* Correct the padding calculation in MKL-DNN conv1d causal
* Restore the empty check in adjust_contrast_v2

Signed-off-by: Yurii <iuriish@yahoo.com>
Signed-off-by: raver119 <raver119@gmail.com>
parent 1e9ff114aa
commit 1f5e15b541
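The first item above means the adjust ops now accept the factor either as a T argument or as a second scalar-array input. A minimal sketch of the resolution pattern the three ops share (simplified from the diffs below; `block`, `INPUT_VARIABLE`, `T_ARG` and `NDArray` are the existing libnd4j op-context types and macros):

```cpp
// Sketch of the shared pattern, not the literal committed code:
// prefer an optional second input array, otherwise wrap T_ARG(0)
// in a temporary scalar NDArray and free it afterwards.
NDArray* factor = nullptr;

if (block.width() > 1) {                // scalar-array passed as input 1
    factor = INPUT_VARIABLE(1);
} else {                                // fall back to the T argument
    factor = new NDArray(output->dataType(), block.launchContext());
    factor->p(0, T_ARG(0));
}

// ... pass factor to the helper / use *factor in the expression ...

if (block.width() == 1)                 // only delete what this op allocated
    delete factor;
```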
@@ -112,7 +112,8 @@
         (4, IsInfOrNan), \
         (5, MatchConditionBool), \
         (6, IsPositive) , \
-        (7, Not)
+        (7, Not), \
+        (8, IsNegative)


 #define TRANSFORM_STRICT_OPS \

@@ -279,7 +280,8 @@
         (3, IsInfOrNan), \
         (4, IsNan), \
         (5, IsInf), \
-        (6, IsPositive)
+        (6, IsPositive), \
+        (7, IsNegative)

 #define REDUCE_SAME_OPS \
         (0, Sum), \

@@ -27,7 +27,8 @@
 namespace nd4j {
 namespace ops {

-CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) {
+////////////////////////////////////////////////////////////////////
+CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) {

     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);

@@ -37,23 +38,31 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) {
         return Status::OK();

     REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required");

-    const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0);
-
     REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
     REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
-    // compute mean before
+
+    NDArray* factor = nullptr;
+
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
+
     // fill up axes vector first
     std::vector<int> axes(input->rankOf() - 1);
     for (auto i = 0; i < axes.size(); ++i)
         axes[i] = i;

     // mean as reduction for last dimension set
     auto mean = input->reduceAlongDims(reduce::Mean, axes);

-    NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext());
-    factorT.p(0, factor);
     // this is contrast calculation
-    output->assign((*input - mean) * factorT + mean);
+    output->assign((*input - mean) * (*factor) + mean);
+
+    if(block.width() == 1)
+        delete factor;

     return Status::OK();
 }

@@ -64,45 +73,54 @@ DECLARE_TYPES(adjust_contrast) {
         ->setSameMode(true);
 }

-CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, -2, 0) {
+////////////////////////////////////////////////////////////////////
+CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) {

     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);

-    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required");
-
-    const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0);
-
     // just skip op if input is empty
     if (input->isEmpty())
         return Status::OK();

     REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
     REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required");
+
+    NDArray* factor = nullptr;
+
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }

     // compute mean before
     std::vector<int> axes(input->rankOf() - 1);
     for (auto i = 0; i < axes.size(); ++i)
         axes[i] = i;

     // mean as reduction for last dimension set
     auto mean = input->reduceAlongDims(reduce::Mean, axes);

     // result as (x - mean) * factor + mean
     auto temp = input->ulike();
     input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp);
-    temp.applyScalar(scalar::Multiply, factor);
+    temp.applyScalarArr(scalar::Multiply, factor);
     temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output);
+
+    if(block.width() == 1)
+        delete factor;

     return Status::OK();
 }

 DECLARE_TYPES(adjust_contrast_v2) {
     getOpDescriptor()->setAllowedInputTypes(nd4j::DataType::ANY)
                      ->setAllowedOutputTypes({ALL_FLOATS})
                      ->setSameMode(true);
 }

 }
 }

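For reference, the contrast adjustment computed by both versions above is, per channel c of the last dimension, with the mean taken over all other dimensions:

$$z_{\ldots,c} = (x_{\ldots,c} - \bar{x}_c)\cdot \text{factor} + \bar{x}_c, \qquad \bar{x}_c = \operatorname{mean}_{\text{all dims except last}}(x_{\ldots,c}).$$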
@@ -24,13 +24,12 @@
 #include <ops/declarable/headers/parity_ops.h>
 #include <ops/declarable/helpers/adjust_hue.h>
-#include <NDArrayFactory.h>

 namespace nd4j {
 namespace ops {


-CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {
+CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) {

     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);

@@ -41,15 +40,26 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {

     const int rank = input->rankOf();
     const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
-    const double delta = T_ARG(0);

+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !");
     REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank);
     REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
-    REQUIRE_TRUE(-1. <= delta && delta <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);

-    NDArray deltaScalarArr = NDArrayFactory::create<double>(delta, block.launchContext());
+    NDArray* delta = nullptr;

-    helpers::adjustHue(block.launchContext(), input, &deltaScalarArr, output, dimC);
+    if(block.width() > 1)
+        delta = INPUT_VARIABLE(1);
+    else {
+        delta = new NDArray(output->dataType(), block.launchContext());
+        delta->p(0, T_ARG(0));
+    }
+
+    REQUIRE_TRUE(-1. <= delta->e<double>(0) && delta->e<double>(0) <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);
+
+    helpers::adjustHue(block.launchContext(), input, delta, output, dimC);
+
+    if(block.width() == 1)
+        delete delta;

     return Status::OK();
 }

@@ -28,7 +28,7 @@
 namespace nd4j {
 namespace ops {

-CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
+CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) {

     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);

@@ -37,16 +37,26 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
     if (input->isEmpty())
         return Status::OK();

     const int rank = input->rankOf();
     const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
-    const double factor = T_ARG(0);

     REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank);
     REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !");

-    NDArray factorScalarArr = NDArrayFactory::create<double>(factor, block.launchContext());
+    NDArray* factor = nullptr;

-    helpers::adjustSaturation(block.launchContext(), input, &factorScalarArr, output, dimC);
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
+
+    helpers::adjustSaturation(block.launchContext(), input, factor, output, dimC);
+
+    if(block.width() == 1)
+        delete factor;

     return Status::OK();
 }

@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -15,27 +16,35 @@
  ******************************************************************************/

 //
-// Created by Yurii Shyrma on 13.12.2017.
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

-#ifndef LIBND4J_POLYGAMMA_H
-#define LIBND4J_POLYGAMMA_H
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_digamma)

-#include <ops/declarable/helpers/helpers.h>
-#include "NDArray.h"
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/gammaMathFunc.h>

 namespace nd4j {
 namespace ops {
-namespace helpers {

-    // calculate the polygamma function
-    void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output);
+CONFIGURABLE_OP_IMPL(digamma, 1, 1, false, 0, 0) {
+
+    auto x = INPUT_VARIABLE(0);
+    auto z = OUTPUT_VARIABLE(0);
+
+    helpers::diGamma(block.launchContext(), *x, *z);
+
+    return Status::OK();
+}
+
+DECLARE_TYPES(digamma) {
+    getOpDescriptor()
+        ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS})
+        ->setSameMode(true);
+}

 }
 }
-}

-#endif //LIBND4J_POLYGAMMA_H
+#endif

@@ -15,14 +15,14 @@
 ******************************************************************************/

 //
-// @author Yurii Shyrma (iuriish@yahoo.com), created on 13.12.2017
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <op_boilerplate.h>
 #if NOT_EXCLUDED(OP_polygamma)

 #include <ops/declarable/CustomOperations.h>
-#include <ops/declarable/helpers/polyGamma.h>
+#include <ops/declarable/helpers/gammaMathFunc.h>

 namespace nd4j {
 namespace ops {

@@ -37,11 +37,11 @@ CONFIGURABLE_OP_IMPL(polygamma, 2, 1, false, 0, 0) {

     Nd4jLong arrLen = n->lengthOf();
     // FIXME: this shit should be single op call, not a loop!
-    auto nPositive = n->reduceNumber(nd4j::reduce::IsPositive, nullptr);
+    auto nNegative = n->reduceNumber(nd4j::reduce::IsNegative, nullptr);
     auto xPositive = x->reduceNumber(nd4j::reduce::IsPositive, nullptr);
-    bool nPositiveFlag = nPositive.e<bool>(0);
-    bool xPositiveFlag = xPositive.e<bool>(0);
-    REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be > 0 !");
+    bool nPositiveFlag = !nNegative.e<bool>(0);   // require all n >= 0
+    bool xPositiveFlag = xPositive.e<bool>(0);    // require all x > 0
+    REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be >= 0 !");
     REQUIRE_TRUE(xPositiveFlag, 0, "POLYGAMMA op: all elements of x array must be > 0 !");

     helpers::polyGamma(block.launchContext(), *n, *x, *output);

@@ -513,7 +513,6 @@ namespace nd4j {
        /**
         * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in
         * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x).
-        * Currently the case n = 0 is not supported.
         *
         * Input arrays:
         *    0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer)

@@ -528,6 +527,20 @@
        DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0);
        #endif

+       /**
+        * This op calculates digamma function psi(x) = derivative of log(Gamma(x))
+        *
+        * Input arrays:
+        *    0: x - abscissa points where to evaluate the digamma function, type float
+        *
+        * Output array:
+        *    0: values of digamma function at corresponding x, type float
+        *
+        */
+       #if NOT_EXCLUDED(OP_digamma)
+       DECLARE_CONFIGURABLE_OP(digamma, 1, 1, false, 0, 0);
+       #endif
+
        /**
         * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value.
         * Input arrays:

@@ -575,44 +588,47 @@
        * This operation adjusts image hue by delta
        * Input arrays:
        *    0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
+       *    1 - optional argument, input scalar-array containing delta
        *
        * T arguments:
-       *    0 - delta value
+       *    0 - optional argument, delta value
        *
        * Int arguments:
        *    0 - optional argument, corresponds to dimension with 3 channels
        */
       #if NOT_EXCLUDED(OP_adjust_hue)
-      DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 1, -2);
+      DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 0, 0);
       #endif

       /**
        * This operation adjusts image saturation by delta
        * Input arrays:
        *    0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
+       *    1 - optional argument, input scalar-array containing saturation factor
        *
        * T arguments:
-       *    0 - saturation factor
+       *    0 - optional argument, saturation factor
        *
        * Int arguments:
        *    0 - optional argument, corresponds to dimension with 3 channels
        */
       #if NOT_EXCLUDED(OP_adjust_saturation)
-      DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2);
+      DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 0, 0);
       #endif

       /**
        * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
        * Input arrays:
        *    0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
+       *    1 - optional argument, input scalar-array containing saturation contrast factor
        *
        * T arguments:
-       *    0 - contrast factor
+       *    0 - optional argument, contrast factor
        *
        */
       #if NOT_EXCLUDED(OP_adjust_contrast)
-      DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, -2, 0);
-      DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, -2, 0);
+      DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 0, 0);
+      DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 0, 0);
       #endif

@@ -1832,7 +1848,7 @@ namespace nd4j {
       #endif

       /**
        * compare_and_bitpack - compare with greater and pack result with uint8
        *
        * input params:
        *    0 - NDArray (input)

@@ -107,12 +107,12 @@ static T betaIncCore(T a, T b, T x) {
         return x;

     const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
-    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1 - x) * b - gammaPart) / a;
+    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);

     if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
-        return front * continuedFraction(a, b, x);
+        return front * continuedFraction(a, b, x) / a;
     else // symmetry relation
-        return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x);
+        return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x) / b;

 }

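The typo fix above restores the standard continued-fraction evaluation of the regularized incomplete beta integral: previously the 1/a factor baked into `front` was also applied on the symmetric branch, where 1/b is needed. In the notation of the code (`front` is the prefactor, CF is `continuedFraction`):

$$I_x(a,b)=\frac{x^{a}(1-x)^{b}}{a\,B(a,b)}\,\mathrm{CF}(a,b,x)\quad\text{for } x\le\frac{a+1}{a+b+2},\qquad I_x(a,b)=1-\frac{x^{a}(1-x)^{b}}{b\,B(a,b)}\,\mathrm{CF}(b,a,1-x)\ \text{otherwise},$$

with $B(a,b)=\Gamma(a)\Gamma(b)/\Gamma(a+b)$, so that $x^{a}(1-x)^{b}/B(a,b)=\exp\big(a\ln x + b\ln(1-x) - (\ln\Gamma(a)+\ln\Gamma(b)-\ln\Gamma(a+b))\big)$, which is exactly the `gammaPart`/`front` computation above.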
@@ -0,0 +1,53 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include<ops/declarable/helpers/gammaMathFunc.h>
+#include <execution/Threads.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+// calculate digamma function for array elements
+template <typename T>
+static void diGamma_(const NDArray& x, NDArray& z) {
+
+    auto func = PRAGMA_THREADS_FOR {
+        for (auto i = start; i < stop; i += increment)
+            z.p(i, diGammaScalar<T>(x.e<T>(i)));
+    };
+    samediff::Threads::parallel_for(func, 0, x.lengthOf());
+}
+
+void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) {
+    BUILD_SINGLE_SELECTOR(x.dataType(), diGamma_, (x, z), FLOAT_TYPES);
+}
+
+BUILD_SINGLE_TEMPLATE(template void diGamma_, (const NDArray& x, NDArray& z), FLOAT_TYPES);
+
+}
+}
+}

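A hypothetical call site for the new CPU helper (the array construction and context call here are illustrative assumptions, not part of the commit; only the `diGamma` signature is taken from the diff above):

```cpp
#include <NDArray.h>
#include <ops/declarable/helpers/gammaMathFunc.h>

// Sketch only: evaluate psi(x) element-wise into a preallocated output
// of the same shape; float types only, per the BUILD_SINGLE_SELECTOR above.
void digammaExample() {
    NDArray x('c', {3}, {0.5, 1.0, 2.0}, nd4j::DataType::DOUBLE);
    NDArray z(x);                          // same shape/type, contents overwritten below
    nd4j::ops::helpers::diGamma(nd4j::LaunchContext::defaultContext(), x, z);
    // expected up to rounding: psi(0.5) ~ -1.9635, psi(1) ~ -0.5772, psi(2) ~ 0.4228
}
```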
@@ -18,7 +18,7 @@
 // Created by Yurii Shyrma on 12.12.2017
 //

-#include<ops/declarable/helpers/polyGamma.h>
+#include<ops/declarable/helpers/gammaMathFunc.h>
 #include<ops/declarable/helpers/zeta.h>
 #include <NDArrayFactory.h>
 #include <execution/Threads.h>

@@ -42,7 +42,7 @@ static FORCEINLINE T getFactorial(const int n) {

     for(int i = 2; i <= n; ++i)
         result *= i;

     return result;
 }

@@ -50,17 +50,15 @@ static FORCEINLINE T getFactorial(const int n) {
 // implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x)
 template <typename T>
 static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n, const T x) {

     // if (n < 0)
     //     throw("polyGamma function: n must be >= 0 !");

     // if (x <= (T)0.)
     //     throw("polyGamma function: x must be > 0 !");

-    // TODO case for n = 0 (digamma)
-
     int sign = (n + 1) % 2  ?  -1 : 1;
     // T factorial = (T)std::tgamma(n + 1);

     return sign * getFactorial<T>(n) * zetaScalar<T>((T)(n + 1), x);
 }

@@ -71,17 +69,18 @@ static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n,
 template <typename T>
 static void polyGamma_(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {

-    NDArray& result = output;
-
-    int xLen = x.lengthOf();
-
     auto func = PRAGMA_THREADS_FOR {
-        for (auto i = start; i < stop; i += increment)
-            result.p(i, polyGammaScalar<T>(context, n.e<int>(i), x.e<T>(i)));
+        for (auto i = start; i < stop; i += increment) {
+            const T order = n.e<T>(i);
+            if(order != static_cast<int>(order))      // if order has fractional part then do not perform calculations and return NAN
+                output.p(i, std::numeric_limits<T>::quiet_NaN());
+            else if (order == 0)                      // polygamma function of zero order is digamma function
+                output.p(i, diGammaScalar<T>(x.e<T>(i)));
+            else
+                output.p(i, polyGammaScalar<T>(context, order, x.e<T>(i)));
+        }
     };
     samediff::Threads::parallel_for(func, 0, x.lengthOf());

-    // return result;
 }

 void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {

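The loop body now implements, per element, the relation stated in the comments: for integer order n >= 1 the polygamma function is expressed through the Hurwitz zeta function, the n = 0 case falls back to the digamma function, and fractional orders yield NaN:

$$\psi^{(n)}(x) = (-1)^{n+1}\, n!\;\zeta(n+1,\,x)\quad (n\ge 1),\qquad \psi^{(0)}(x) = \psi(x) = \frac{d}{dx}\ln\Gamma(x).$$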
@@ -89,20 +89,6 @@ __device__ T continuedFractionCuda(const T a, const T b, const T x) {
     return 1.f / 0.f;   // no convergence, more iterations is required
 }

-///////////////////////////////////////////////////////////////////
-// evaluates incomplete beta function for positive a and b, and x between 0 and 1.
-template <typename T>
-__device__ T betaIncCoreCuda(T a, T b, T x) {
-
-    const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
-    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1 - x) * b - gammaPart) / a;
-
-    if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
-        return front * continuedFractionCuda(a, b, x);
-    else // symmetry relation
-        return static_cast<T>(1) - front * continuedFractionCuda(b, a, static_cast<T>(1) - x);
-}
-
 ///////////////////////////////////////////////////////////////////
 template<typename T>
 __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,

@@ -115,12 +101,21 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,

     const Nd4jLong j = blockIdx.x;          // one block per each element

-    Nd4jLong len = shape::length(xShapeInfo);
+    T& z = *(reinterpret_cast<T*>(vz) + shape::getIndexOffset(j, zShapeInfo));

-    const T a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo));
-    const T b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
-    const T x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
-    T& z = *(reinterpret_cast<T*>(vz) + shape::getIndexOffset(j, zShapeInfo));
+    __shared__ T a, b, x;
+    __shared__ bool symmCond;
+
+    if (threadIdx.x == 0) {
+
+        a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo));
+        b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
+        x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
+
+        symmCond = x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2));
+    }
+    __syncthreads();

     // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5
     if(a == b && x == static_cast<T>(0.5)) {

@@ -135,17 +130,31 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,

     if(threadIdx.x % 2 == 0) {     /***** even part *****/
         const int m = threadIdx.x + 1;
-        sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast<T>(1)) * (a + 2 * m));
+        if(symmCond)
+            sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast<T>(1)) * (a + 2 * m));
+        else
+            sharedMem[threadIdx.x] = m * (a - m) * (1.f-x) / ((b + 2 * m - static_cast<T>(1)) * (b + 2 * m));
     }
     else {                         /***** odd part *****/
         const int m = threadIdx.x;
-        sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast<T>(1)) * (a + 2 * m));
+        if(symmCond)
+            sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast<T>(1)) * (a + 2 * m));
+        else
+            sharedMem[threadIdx.x] = -(b + m) * (a + b + m) * (1.f-x) / ((b + 2 * m + static_cast<T>(1)) * (b + 2 * m));
     }

     __syncthreads();

-    if(threadIdx.x == 0)
-        z = betaIncCoreCuda(a, b, x);
+    if(threadIdx.x == 0) {
+
+        const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
+        const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
+
+        if (symmCond)
+            z = front * continuedFractionCuda(a, b, x) / a;
+        else // symmetry relation
+            z = static_cast<T>(1) - front * continuedFractionCuda(b, a, static_cast<T>(1) - x) / b;
+    }
 }

 ///////////////////////////////////////////////////////////////////

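The even/odd shared-memory terms above are the standard coefficients of the incomplete-beta continued fraction; when `symmCond` is false the kernel now evaluates the same coefficients with a and b swapped and x replaced by 1 - x, matching the symmetry branch of the CPU version:

$$d_{2m} = \frac{m\,(b-m)\,x}{(a+2m-1)(a+2m)},\qquad d_{2m+1} = -\frac{(a+m)(a+b+m)\,x}{(a+2m)(a+2m+1)}.$$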
@@ -0,0 +1,78 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include<ops/declarable/helpers/gammaMathFunc.h>
+#include <NDArrayFactory.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+__global__ static void diGammaCuda(const void *vx, const Nd4jLong *xShapeInfo,
+                                         void *vz, const Nd4jLong *zShapeInfo) {
+
+    const auto x = reinterpret_cast<const T*>(vx);
+          auto z = reinterpret_cast<T*>(vz);
+
+    __shared__ Nd4jLong len;
+    __shared__ bool sameOffset;
+
+    if (threadIdx.x == 0) {
+        len = shape::length(xShapeInfo);
+        sameOffset = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
+    }
+    __syncthreads();
+
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += gridDim.x * blockDim.x) {
+
+        const auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+        const auto zOffset = sameOffset ? xOffset : shape::getIndexOffset(i, zShapeInfo);
+
+        z[zOffset] = diGammaScalar<T>(x[xOffset]);
+    }
+}
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+static void diGammaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
+
+    diGammaCuda<T><<<blocksPerGrid, threadsPerBlock, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo);
+}
+
+///////////////////////////////////////////////////////////////////
+void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) {
+
+    int threadsPerBlock = MAX_NUM_THREADS / 2;
+    int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+
+    NDArray::prepareSpecialUse({&z}, {&x});
+    BUILD_SINGLE_SELECTOR(x.dataType(), diGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);
+    NDArray::registerSpecialUse({&z}, {&x});
+}
+
+BUILD_SINGLE_TEMPLATE(template void diGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES);
+
+}
+}
+}

@@ -18,7 +18,7 @@
 // @author Yurii Shyrma (iuriish@yahoo.com), created on 26.04.2019
 //

-#include<ops/declarable/helpers/polyGamma.h>
+#include<ops/declarable/helpers/gammaMathFunc.h>
 #include<ops/declarable/helpers/zeta.h>
 #include <NDArrayFactory.h>

@@ -37,9 +37,13 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo,
           auto z = reinterpret_cast<T*>(vz);

     __shared__ Nd4jLong len;
+    __shared__ bool sameOffsetNX, sameOffsetNZ;

-    if (threadIdx.x == 0)
+    if (threadIdx.x == 0) {
         len = shape::length(nShapeInfo);
+        sameOffsetNX = shape::haveSameShapeAndStrides(xShapeInfo, nShapeInfo);
+        sameOffsetNZ = shape::haveSameShapeAndStrides(zShapeInfo, nShapeInfo);
+    }
     __syncthreads();

     const auto tid = blockIdx.x * blockDim.x + threadIdx.x;

@@ -48,19 +52,26 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo,
     for (int i = tid; i < len; i += totalThreads) {

         const auto nOffset = shape::getIndexOffset(i, nShapeInfo);
-        const auto xOffset = shape::getIndexOffset(i, xShapeInfo);
-        const auto zOffset = shape::getIndexOffset(i, zShapeInfo);
+        const auto xOffset = sameOffsetNX ? nOffset : shape::getIndexOffset(i, xShapeInfo);
+        const auto zOffset = sameOffsetNZ ? nOffset : shape::getIndexOffset(i, zShapeInfo);

-        const T nVal = n[nOffset];
+        const T order = n[nOffset];

-        int sign = (static_cast<int>(nVal) + 1) % 2  ?  -1 : 1;
+        int sign = (static_cast<int>(order) + 1) % 2  ?  -1 : 1;

-        T factorial = 1;
-        if(nVal != 0 && nVal != 1)
-            for(int i = 2; i <= nVal; ++i)
-                factorial *= i;
+        if(order != static_cast<int>(order)) {
+            z[zOffset] = DataTypeUtils::nanOrZero<T>();
+        }
+        else if(order == 0) {
+            z[zOffset] = diGammaScalar<T>(x[xOffset]);
+        }
+        else {
+            T factorial = 1;
+            for(int i = 2; i <= order; ++i)
+                factorial *= i;

-        z[zOffset] = sign * factorial * zetaScalar<T>(nVal + 1, x[xOffset]);
+            z[zOffset] = sign * factorial * zetaScalar<T>(order + 1, x[xOffset]);
+        }
     }
 }

@@ -76,7 +87,7 @@ void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x

     NDArray::prepareSpecialUse({&z}, {&n, &x});

-    int threadsPerBlock = MAX_NUM_THREADS;
+    int threadsPerBlock = MAX_NUM_THREADS / 2;
     int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;

     BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.getSpecialBuffer(), n.getSpecialShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);

@@ -0,0 +1,100 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#ifndef LIBND4J_GAMMAMATHFUNC_H
+#define LIBND4J_GAMMAMATHFUNC_H
+
+#include <ops/declarable/helpers/helpers.h>
+#include "NDArray.h"
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+    // calculate the digamma function for each element for array
+    void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z);
+
+    // calculate the polygamma function
+    void polyGamma(nd4j::LaunchContext* context, const NDArray& n, const NDArray& x, NDArray& z);
+
+    // calculate the digamma function for one element
+    // implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x)
+    template <typename T>
+    _CUDA_HD T diGammaScalar(T x) {
+
+        const int xInt = static_cast<int>(x);
+
+        // negative and zero
+        if(x <= 0) {
+            if(x == xInt)   // integer
+                return DataTypeUtils::infOrMax<T>();
+            else
+                return diGammaScalar<T>(1 - x) - M_PI / nd4j::math::nd4j_tan<T,T>(M_PI * x);  // use reflection formula psi(1-x) = psi(x) + pi*cot(pi*x)
+        }
+
+        // positive integer
+        if(x == xInt && xInt <= 20) {        // psi(n) = -Euler_Mascheroni_const + sum_from_k=1_to_n-1( 1/k ), for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n
+            T result = -0.577215664901532;
+            for (uint i = 1; i <= xInt - 1; ++i) {
+                result += static_cast<T>(1) / i;
+            }
+            return result;
+        }
+
+        // positive half-integer
+        if(x - xInt == 0.5 && xInt <= 20) {  // psi(n+0.5) = -Euler_Mascheroni_const - 2*ln(2) + sum_from_k=1_to_n( 2/(2*k-1) ), for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n
+            T result = -0.577215664901532 - 2 * nd4j::math::nd4j_log<T,T>(2);
+            for (uint i = 1; i <= xInt; ++i) {
+                result += static_cast<T>(2) / (2*i - 1);
+            }
+            return result;
+        }
+
+        // positive, smaller then 5; we should use number > 5 in order to have satisfactory accuracy in asymptotic expansion
+        if(x < 5)
+            return diGammaScalar<T>(1 + x) - static_cast<T>(1) / x;       // recurrence formula  psi(x) = psi(x+1) - 1/x.
+
+        // *** other positive **** //
+
+        // truncated expansion formula (from wiki)
+        // psi(x) = log(x) - 1/(2*x) - 1/(12*x^2) + 1/(120*x^4) - 1/(252*x^6) + 1/(240*x^8) - 5/(660*x^10) + 691/(32760*x^12) - 1/(12*x^14) + ...
+
+        if(x >= (sizeof(T) > 4 ? 1.e16 : 1.e8))       // if x is too big take into account only log(x)
+            return nd4j::math::nd4j_log<T,T>(x);
+
+        // coefficients used in truncated asymptotic expansion formula
+        const T coeffs[7] = {-(T)1/12, (T)1/120, -(T)1/252, (T)1/240, -(T)5/660, (T)691/32760, -(T)1/12};
+        // const T coeffs[7] = {-0.0833333333333333, 0.00833333333333333, -0.00396825396825397, 0.00416666666666667, -0.00757575757575758, 0.0210927960927961, -0.0833333333333333};
+
+        const T x2Inv = static_cast<T>(1) / (x * x);
+        T result = 0;
+
+        for (int i = 6; i >= 0; --i)
+            result = (result + coeffs[i]) * x2Inv;
+        return result + nd4j::math::nd4j_log<T,T>(x) - static_cast<T>(0.5) / x;
+    }
+
+}
+}
+}
+
+#endif //LIBND4J_GAMMAMATHFUNC_H

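The branches of `diGammaScalar` above correspond to these identities (a summary of the code comments; gamma denotes the Euler-Mascheroni constant, approximately 0.5772156649):

$$\psi(x) = \psi(1-x) - \pi\cot(\pi x)\ \ (x<0,\ x\notin\mathbb{Z}),\qquad \psi(x) = \psi(x+1) - \tfrac{1}{x},$$

$$\psi(n) = -\gamma + \sum_{k=1}^{n-1}\frac{1}{k},\qquad \psi\!\left(n+\tfrac12\right) = -\gamma - 2\ln 2 + \sum_{k=1}^{n}\frac{2}{2k-1},$$

$$\psi(x) \approx \ln x - \frac{1}{2x} - \frac{1}{12x^{2}} + \frac{1}{120x^{4}} - \frac{1}{252x^{6}} + \cdots\qquad (x\ \text{large}).$$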
@ -36,7 +36,7 @@ namespace platforms {
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
|
static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
|
||||||
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
|
const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
|
||||||
const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
|
const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode,
|
||||||
const int isNCHW) {
|
const int isNCHW) {
|
||||||
|
|
||||||
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width;
|
||||||
|
@ -44,8 +44,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
|
||||||
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
|
ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
|
||||||
indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
|
||||||
|
|
||||||
if(isSameMode) // SAME
|
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
|
||||||
ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
|
|
||||||
|
|
||||||
dnnl_memory_desc_t empty;
|
dnnl_memory_desc_t empty;
|
||||||
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
|
dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
|
||||||
|
@ -53,7 +52,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
|
||||||
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
|
||||||
empty);
|
empty);
|
||||||
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
|
||||||
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
|
mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
|
||||||
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
|
bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
|
||||||
bias, output,
|
bias, output,
|
||||||
&conv_src_md, nullptr, &conv_weights_md, nullptr,
|
&conv_src_md, nullptr, &conv_weights_md, nullptr,
|
||||||
|
@ -115,13 +114,11 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
PLATFORM_IMPL(conv2d) {
|
PLATFORM_IMPL(conv2d) {
|
||||||
auto input = INPUT_VARIABLE(
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
|
||||||
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
||||||
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(
|
auto output = OUTPUT_VARIABLE(0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
||||||
0); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
|
|
||||||
|
|
||||||
int sH = INT_ARG(2); // strides height
|
int sH = INT_ARG(2); // strides height
|
||||||
int sW = INT_ARG(3); // strides width
|
int sW = INT_ARG(3); // strides width
|
||||||
|
@ -129,13 +126,13 @@ PLATFORM_IMPL(conv2d) {
|
||||||
int pW = INT_ARG(5); // paddings width
|
int pW = INT_ARG(5); // paddings width
|
||||||
int dH = INT_ARG(6); // dilations height
|
int dH = INT_ARG(6); // dilations height
|
||||||
int dW = INT_ARG(7); // dilations width
|
int dW = INT_ARG(7); // dilations width
|
||||||
int isSameMode = INT_ARG(8); // 0-VALID, 1-SAME
|
int paddingMode = INT_ARG(8); // 0-VALID, 1-SAME
|
||||||
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
|
||||||
|
|
||||||
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
|
int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
|
||||||
    int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width

-   conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW);
+   conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);

    return Status::OK();
}

@@ -155,18 +152,13 @@ PLATFORM_CHECK(conv2d) {

//////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(conv2d_bp) {

    auto input   = INPUT_VARIABLE(0);                                         // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto weights = INPUT_VARIABLE(1);                                         // [kH, kW, iC, oC] always
    auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;           // [oC]
    auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

    auto gradI = OUTPUT_VARIABLE(0);                                          // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
    auto gradW = OUTPUT_VARIABLE(1);                                          // [kH, kW, iC, oC] always
    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;            // [oC]

    int kH = INT_ARG(0);                                                      // filter(kernel) height

@@ -177,7 +169,7 @@ PLATFORM_IMPL(conv2d_bp) {
    int pW = INT_ARG(5);                                                      // paddings width
    int dH = INT_ARG(6);                                                      // dilations height
    int dW = INT_ARG(7);                                                      // dilations width
-   int isSameMode = INT_ARG(8);                                              // 0-VALID, 1-SAME
+   int paddingMode = INT_ARG(8);                                             // 0-VALID, 1-SAME
    int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;         // INT_ARG(9): 0-NCHW, 1-NHWC

    REQUIRE_TRUE(input->rankOf() == 4, 0,

@@ -195,8 +187,7 @@ PLATFORM_IMPL(conv2d_bp) {
    ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
                                               indIiH, indWiC, indWoC, indWkH, indOoH);

-   if (isSameMode) // SAME
-       ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+   ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);

    dnnl_memory_desc_t empty;
    dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
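For reference: with the boolean isSameMode replaced by an integer paddingMode, calcPadding2D is now called unconditionally and selects the padding scheme itself. The arithmetic usually behind VALID (0) and SAME (1) padding, sketched here for a single spatial axis with made-up names, is only an illustration and not the actual ConvolutionUtils::calcPadding2D body:

    // sketch only: how much total padding one axis needs so SAME keeps ceil(i / s) outputs
    #include <algorithm>
    #include <cstdio>

    static int sameModeTotalPad(int i, int k, int s, int d) {
        const int o   = (i + s - 1) / s;            // SAME output length: ceil(i / s)
        const int eff = (k - 1) * d + 1;            // dilated kernel extent
        return std::max(0, (o - 1) * s + eff - i);  // VALID (paddingMode == 0) simply uses 0
    }

    int main() {
        // e.g. iH = 4, kH = 3, sH = 1, dH = 1 -> total pad 2, split between top and bottom
        std::printf("total pad along H: %d\n", sameModeTotalPad(4, 3, 1, 1));
        return 0;
    }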
@@ -204,7 +195,7 @@ PLATFORM_IMPL(conv2d_bp) {
    dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
                       user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
    dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
-   mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
+   mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
                                           bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
                                           gradB, gradO,
                                           &conv_src_md, &conv_diff_src_md, &conv_weights_md,

@@ -342,18 +333,13 @@ PLATFORM_CHECK(conv2d_bp) {
    if (::optimalLevel() < 2)
        return false;

    auto input   = INPUT_VARIABLE(0);                                         // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
    auto weights = INPUT_VARIABLE(1);                                         // [kH, kW, iC, oC] always
    auto bias    = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;           // [oC]
    auto gradO   = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next

    auto gradI = OUTPUT_VARIABLE(0);                                          // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
    auto gradW = OUTPUT_VARIABLE(1);                                          // [kH, kW, iC, oC] always
    auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;            // [oC]
@@ -1568,11 +1568,39 @@ namespace simdOps {
            return opOutput + old;
        }

        op_def static Z update(X old, X opOutput, X *extraParams) {
            return opOutput + old;
        }

+       op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) {
+           return reduction;
+       }
+   };
+
+   template <typename X, typename Z>
+   class IsNegative {
+   public:
+       no_op_exec_special_bool
+       no_op_exec_special_bool_cuda
+
+       no_op_exec_special_accumulation
+       no_op_exec_special_accumulation_cuda
+
+       op_def static Z op(X d1, X *params) {
+           return d1 < (X)0.f;
+       }
+
+       op_def static X startingValue(const X *input) {
+           return static_cast<X>(0);
+       }
+
+       op_def static Z merge(X old, X opOutput, X *extraParams) {
+           return opOutput + old;
+       }
+
+       op_def static Z update(X old, X opOutput, X *extraParams) {
+           return opOutput + old;
+       }
+
        op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) {
            return reduction;
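Put simply, the new IsNegative functor maps each element x to (x < 0); and since merge/update just add partial results onto a startingValue of 0, the accumulation variants end up counting negative elements. A minimal plain-C++ sketch of that behaviour, outside the simdOps templates:

    #include <cstdio>
    #include <vector>

    template <typename X>
    bool isNegativeOp(X d1) { return d1 < static_cast<X>(0.f); }  // mirrors op(d1, params)

    int main() {
        std::vector<float> v{-1.5f, 0.f, 2.f, -0.25f, 3.f};
        long negatives = 0;                                        // startingValue(...) == 0
        for (float x : v)
            negatives += isNegativeOp(x) ? 1 : 0;                  // update/merge: opOutput + old
        std::printf("negatives: %ld of %zu\n", negatives, v.size());  // -> negatives: 2 of 5
        return 0;
    }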
@@ -1008,6 +1008,38 @@ TEST_F(ConvolutionTests1, conv1d_causal_6) {
    delete results;
}

+//////////////////////////////////////////////////////////////////////
+TEST_F(ConvolutionTests1, conv1d_causal_7) {
+
+    int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=1;
+    int oW = (iW-1)/sW + 1;
+    int paddingMode = 2;             // CAUSAL
+    int dataFormat  = 1;             // 1-NHWC, 0-NCHW
+
+    NDArray input('c', {bS, iW, iC}, nd4j::DataType::FLOAT32);
+    NDArray weights('c', {kW, iC, oC}, nd4j::DataType::FLOAT32);
+
+    NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, 34.299999, 36.400002, 49.899998, 53.800003, 57.699997,
+            61.599998, 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, 129.100006,
+            140.199997, 151.300003, 162.399994, 148.899994, 161.800003, 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, 188.500000, 205.000000,
+            221.500000, 238.000000, 208.299988, 226.600006, 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, 247.899994, 269.799988, 291.700012,
+            313.600006, 267.700012, 291.399994, 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, 307.299988, 334.600006, 361.899994, 389.200012}, nd4j::DataType::FLOAT32);
+
+    input.linspace(1., 1.);
+    weights.linspace(0.1, 0.1);
+
+    nd4j::ops::conv1d op;
+    auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+    auto output = results->at(0);
+
+    ASSERT_EQ(Status::OK(), results->status());
+
+    ASSERT_TRUE(expOutput.isSameShape(output));
+    ASSERT_TRUE(expOutput.equalsTo(output));
+
+    delete results;
+}
+
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, conv1d_causal_bp_1) {
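The numbers in conv1d_causal_7 can be reproduced by hand: in CAUSAL mode with stride 1 the output keeps the input length and all of the (kW - 1) * dW zero padding sits on the left. A stand-alone check in plain C++ (no libnd4j; the layout assumptions match the test, input [iW, iC] for batch 0 and weights [kW, iC, oC]) prints 11.00 11.60 12.20 12.80, 30.10 32.20 34.30 36.40, 49.90 ... in agreement with the first rows of expOutput:

    #include <cstdio>
    #include <vector>

    int main() {
        const int iW = 8, iC = 3, oC = 4, kW = 2, dW = 1;            // sizes from the test, batch 0 only
        std::vector<float> x(iW * iC), w(kW * iC * oC);
        for (size_t i = 0; i < x.size(); ++i) x[i] = 1.f + i;        // input.linspace(1., 1.)
        for (size_t i = 0; i < w.size(); ++i) w[i] = 0.1f * (i + 1); // weights.linspace(0.1, 0.1)

        const int padLeft = (kW - 1) * dW;                           // causal: zero padding on the left only
        for (int ow = 0; ow < 3; ++ow) {                             // first three output positions
            for (int oc = 0; oc < oC; ++oc) {
                float s = 0.f;
                for (int kw = 0; kw < kW; ++kw) {
                    const int iw = ow + kw * dW - padLeft;           // stride sW = 1
                    if (iw < 0) continue;                            // tap falls on the zero padding
                    for (int ic = 0; ic < iC; ++ic)
                        s += x[iw * iC + ic] * w[(kw * iC + ic) * oC + oc];
                }
                std::printf("%7.2f", s);
            }
            std::printf("\n");
        }
        return 0;
    }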
@ -1174,14 +1206,14 @@ TYPED_TEST(TypedConvolutionTests1, conv2d_bp_test3) {
|
||||||
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
|
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
|
||||||
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW});
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oH, oW});
|
||||||
|
|
||||||
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f,
|
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iH, iW},{ 0.567f, 1.224f, 0.66f, 1.314f, 2.82f, 1.512f, 1.386f, 2.976f, 1.596f, 0.801f, 1.71f, 0.912f, 0.657f, 1.422f, 0.768f, 1.53f, 3.288f, 1.764f, 1.602f, 3.444f, 1.848f, 0.927f, 1.98f, 1.056f,
|
||||||
0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f,
|
0.747f, 1.62f, 0.876f, 1.746f, 3.756f, 2.016f, 1.818f, 3.912f, 2.1f, 1.053f, 2.25f, 1.2f, 0.837f, 1.818f, 0.984f, 1.962f, 4.224f, 2.268f, 2.034f, 4.38f, 2.352f, 1.179f, 2.52f, 1.344f,
|
||||||
1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f,
|
1.467f, 3.06f, 1.596f, 3.186f, 6.636f, 3.456f, 3.402f, 7.08f, 3.684f, 1.845f, 3.834f, 1.992f, 1.773f, 3.69f, 1.92f, 3.834f, 7.968f, 4.14f, 4.05f, 8.412f, 4.368f, 2.187f, 4.536f, 2.352f,
|
||||||
2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f});
|
2.079f, 4.32f, 2.244f, 4.482f, 9.3f, 4.824f, 4.698f, 9.744f, 5.052f, 2.529f, 5.238f, 2.712f, 2.385f, 4.95f, 2.568f, 5.13f, 10.632f, 5.508f, 5.346f, 11.076f, 5.736f, 2.871f, 5.94f, 3.072f});
|
||||||
|
|
||||||
auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f,
|
auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kH, kW},{1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f,
|
||||||
1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f,
|
1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 1.3600e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f,
|
||||||
2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f,
|
2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.0000e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f,
|
||||||
2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f});
|
2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f, 2.6400e+00f});
|
||||||
auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{0.68f, 1.f, 1.32f});
|
auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{0.68f, 1.f, 1.32f});
|
||||||
|
|
||||||
|
@ -1252,20 +1284,20 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test1) {
|
||||||
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
|
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
|
||||||
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC});
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC});
|
||||||
|
|
||||||
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f,
|
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{0.226f, 0.343f, 0.46f, 0.577f, 1.172f, 1.46f, 1.748f, 2.036f, 1.892f, 2.288f, 2.684f, 3.08f, 1.284f, 1.581f, 1.878f, 2.175f, 4.458f, 5.133f, 5.808f, 6.483f, 6.186f, 7.023f, 7.86f, 8.697f, 3.39f, 3.93f, 4.47f, 5.01f, 9.642f, 10.803f, 11.964f, 13.125f, 11.37f, 12.693f, 14.016f, 15.339f,
|
||||||
5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f,
|
5.266f, 5.707f, 6.148f, 6.589f, 12.98f, 13.916f, 14.852f, 15.788f, 14.564f, 15.608f, 16.652f, 17.696f, 6.284f, 7.166f, 8.048f, 8.93f, 17.896f, 19.768f, 21.64f, 23.512f, 21.928f, 24.016f, 26.104f, 28.192f, 18.12f, 19.686f, 21.252f, 22.818f, 45.852f, 49.146f, 52.44f, 55.734f, 53.196f, 56.814f, 60.432f, 64.05f,
|
||||||
28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f,
|
28.164f, 30.216f, 32.268f, 34.32f, 67.884f, 72.15f, 76.416f, 80.682f, 75.228f, 79.818f, 84.408f, 88.998f, 29.324f, 30.854f, 32.384f, 33.914f, 67.432f, 70.6f, 73.768f, 76.936f, 73.192f, 76.576f, 79.96f, 83.344f, 27.884f, 30.062f, 32.24f, 34.418f, 66.28f, 70.744f, 75.208f, 79.672f, 70.312f, 74.992f, 79.672f, 84.352f,
|
||||||
58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f,
|
58.296f, 61.806f, 65.316f, 68.826f, 133.98f, 141.162f, 148.344f, 155.526f, 141.324f, 148.83f, 156.336f, 163.842f, 68.34f, 72.336f, 76.332f, 80.328f, 156.012f, 164.166f, 172.32f, 180.474f, 163.356f, 171.834f, 180.312f, 188.79f, 61.292f, 64.118f, 66.944f, 69.77f, 136.552f, 142.312f, 148.072f, 153.832f, 142.312f, 148.288f, 154.264f, 160.24f,
|
||||||
9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f,
|
9.298f, 11.359f, 13.42f, 15.481f, 27.092f, 31.268f, 35.444f, 39.62f, 27.812f, 32.096f, 36.38f, 40.664f, 26.556f, 29.769f, 32.982f, 36.195f, 66.666f, 73.173f, 79.68f, 86.187f, 68.394f, 75.063f, 81.732f, 88.401f, 28.662f, 32.118f, 35.574f, 39.03f, 71.85f, 78.843f, 85.836f, 92.829f, 73.578f, 80.733f, 87.888f, 95.043f,
|
||||||
29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f,
|
29.89f, 32.275f, 34.66f, 37.045f, 70.004f, 74.828f, 79.652f, 84.476f, 71.588f, 76.52f, 81.452f, 86.384f, 71.084f, 75.854f, 80.624f, 85.394f, 163.048f, 172.696f, 182.344f, 191.992f, 167.08f, 176.944f, 186.808f, 196.672f, 138.648f, 146.046f, 153.444f, 160.842f, 310.236f, 325.194f, 340.152f, 355.11f, 317.58f, 332.862f, 348.144f, 363.426f,
|
||||||
148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f,
|
148.692f, 156.576f, 164.46f, 172.344f, 332.268f, 348.198f, 364.128f, 380.058f, 339.612f, 355.866f, 372.12f, 388.374f, 125.228f, 130.646f, 136.064f, 141.482f, 274.792f, 285.736f, 296.68f, 307.624f, 280.552f, 291.712f, 302.872f, 314.032f, 92.684f, 98.75f, 104.816f, 110.882f, 211.432f, 223.672f, 235.912f, 248.152f, 215.464f, 227.92f, 240.376f, 252.832f,
|
||||||
178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, 390.928f});
|
178.824f, 188.166f, 197.508f, 206.85f, 398.364f, 417.21f, 436.056f, 454.902f, 405.708f, 424.878f, 444.048f, 463.218f, 188.868f, 198.696f, 208.524f, 218.352f, 420.396f, 440.214f, 460.032f, 479.85f, 427.74f, 447.882f, 468.024f, 488.166f, 157.196f, 163.91f, 170.624f, 177.338f, 343.912f, 357.448f, 370.984f, 384.52f, 349.672f, 363.424f, 377.176f, 390.928f});
|
||||||
|
|
||||||
auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f,
|
auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 120.96f, 122.04f, 123.12f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f, 79.56f, 80.28f, 81.f,
|
||||||
154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f,
|
154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 154.8f, 156.24f, 157.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f, 101.76f, 102.72f, 103.68f,
|
||||||
111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f,
|
111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 111.24f, 112.32f, 113.4f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f, 73.08f, 73.8f, 74.52f,
|
||||||
67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f,
|
67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 67.68f, 68.4f, 69.12f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f, 44.4f, 44.88f, 45.36f,
|
||||||
85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f,
|
85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 85.92f, 86.88f, 87.84f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f, 56.32f, 56.96f, 57.6f,
|
||||||
61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f});
|
61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 61.2f, 61.92f, 62.64f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f, 40.08f, 40.56f, 41.04f});
|
||||||
// auto expGradB('c', {oC},{});
|
// auto expGradB('c', {oC},{});
|
||||||
|
|
||||||
|
@ -1302,18 +1334,18 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test2) {
|
||||||
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
|
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
|
||||||
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC});
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oD, oH, oW, oC});
|
||||||
|
|
||||||
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f,
|
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC},{ 0.014f, 0.032f, 0.05f, 0.068f, 0.118f, 0.181f, 0.244f, 0.307f, 0.212f, 0.257f, 0.302f, 0.347f, 0.208f, 0.298f, 0.388f, 0.478f, 1.028f, 1.262f, 1.496f, 1.73f, 1.036f, 1.18f, 1.324f, 1.468f, 0.928f, 1.018f, 1.108f, 1.198f, 2.9f, 3.134f, 3.368f, 3.602f, 2.188f, 2.332f, 2.476f, 2.62f,
|
||||||
1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f,
|
1.202f, 1.274f, 1.346f, 1.418f, 3.142f, 3.313f, 3.484f, 3.655f, 2.048f, 2.147f, 2.246f, 2.345f, 0.532f, 0.676f, 0.82f, 0.964f, 2.324f, 2.666f, 3.008f, 3.35f, 2.008f, 2.206f, 2.404f, 2.602f, 3.584f, 3.98f, 4.376f, 4.772f, 10.552f, 11.452f, 12.352f, 13.252f, 7.4f, 7.904f, 8.408f, 8.912f,
|
||||||
6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f,
|
6.752f, 7.148f, 7.544f, 7.94f, 17.752f, 18.652f, 19.552f, 20.452f, 11.432f, 11.936f, 12.44f, 12.944f, 5.932f, 6.184f, 6.436f, 6.688f, 14.42f, 14.978f, 15.536f, 16.094f, 8.704f, 9.01f, 9.316f, 9.622f, 3.11f, 3.236f, 3.362f, 3.488f, 7.39f, 7.669f, 7.948f, 8.227f, 4.388f, 4.541f, 4.694f, 4.847f,
|
||||||
8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f,
|
8.56f, 8.866f, 9.172f, 9.478f, 19.892f, 20.558f, 21.224f, 21.89f, 11.548f, 11.908f, 12.268f, 12.628f, 11.008f, 11.314f, 11.62f, 11.926f, 25.22f, 25.886f, 26.552f, 27.218f, 14.428f, 14.788f, 15.148f, 15.508f, 7.322f, 7.502f, 7.682f, 7.862f, 16.462f, 16.849f, 17.236f, 17.623f, 9.248f, 9.455f, 9.662f, 9.869f,
|
||||||
0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f,
|
0.158f, 0.392f, 0.626f, 0.86f, 1.27f, 1.765f, 2.26f, 2.755f, 1.22f, 1.481f, 1.742f, 2.003f, 2.224f, 2.746f, 3.268f, 3.79f, 6.788f, 7.886f, 8.984f, 10.082f, 4.78f, 5.356f, 5.932f, 6.508f, 6.4f, 6.922f, 7.444f, 7.966f, 15.572f, 16.67f, 17.768f, 18.866f, 9.388f, 9.964f, 10.54f, 11.116f,
|
||||||
4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f,
|
4.802f, 5.09f, 5.378f, 5.666f, 11.206f, 11.809f, 12.412f, 13.015f, 6.512f, 6.827f, 7.142f, 7.457f, 6.004f, 6.58f, 7.156f, 7.732f, 14.996f, 16.202f, 17.408f, 18.614f, 9.208f, 9.838f, 10.468f, 11.098f, 17.984f, 19.244f, 20.504f, 21.764f, 42.808f, 45.436f, 48.064f, 50.692f, 25.256f, 26.624f, 27.992f, 29.36f,
|
||||||
28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f,
|
28.064f, 29.324f, 30.584f, 31.844f, 63.832f, 66.46f, 69.088f, 71.716f, 36.2f, 37.568f, 38.936f, 40.304f, 18.316f, 19.f, 19.684f, 20.368f, 40.916f, 42.338f, 43.76f, 45.182f, 22.816f, 23.554f, 24.292f, 25.03f, 8.438f, 8.78f, 9.122f, 9.464f, 18.91f, 19.621f, 20.332f, 21.043f, 10.58f, 10.949f, 11.318f, 11.687f,
|
||||||
20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f});
|
20.944f, 21.682f, 22.42f, 23.158f, 46.388f, 47.918f, 49.448f, 50.978f, 25.66f, 26.452f, 27.244f, 28.036f, 26.848f, 27.586f, 28.324f, 29.062f, 58.628f, 60.158f, 61.688f, 63.218f, 31.996f, 32.788f, 33.58f, 34.372f, 16.106f, 16.502f, 16.898f, 17.294f, 34.894f, 35.713f, 36.532f, 37.351f, 18.896f, 19.319f, 19.742f, 20.165f});
|
||||||
|
|
||||||
auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,
|
auto expGradW = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC},{7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,
|
||||||
7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,
|
7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,
|
||||||
7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,
|
7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f,
|
||||||
7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f});
|
7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f, 7.52f, 7.84f, 8.16f});
|
||||||
// auto expGradB('c', {oC},{});
|
// auto expGradB('c', {oC},{});
|
||||||
|
|
||||||
|
@ -1350,20 +1382,20 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_bp_test3) {
|
||||||
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
|
auto bias = NDArrayFactory::create<TypeParam>('c', {oC}, {1,2,3});
|
||||||
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oD, oH, oW});
|
auto gradO = NDArrayFactory::create<TypeParam>('c', {bS, oC, oD, oH, oW});
|
||||||
|
|
||||||
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f,
|
auto expGradI = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW},{2.091f, 4.356f, 2.268f, 4.53f, 9.42f, 4.896f, 4.65f, 9.672f, 5.028f, 2.517f, 5.226f, 2.712f, 4.932f, 10.242f, 5.316f, 10.62f, 22.02f, 11.412f, 10.908f, 22.62f, 11.724f, 5.868f, 12.15f, 6.288f, 2.913f, 6.03f, 3.12f, 6.234f, 12.888f, 6.66f, 6.402f, 13.236f, 6.84f, 3.423f, 7.068f, 3.648f,
|
||||||
2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f,
|
2.415f, 5.04f, 2.628f, 5.25f, 10.932f, 5.688f, 5.37f, 11.184f, 5.82f, 2.913f, 6.054f, 3.144f, 5.724f, 11.898f, 6.18f, 12.348f, 25.62f, 13.284f, 12.636f, 26.22f, 13.596f, 6.804f, 14.094f, 7.296f, 3.381f, 7.002f, 3.624f, 7.242f, 14.976f, 7.74f, 7.41f, 15.324f, 7.92f, 3.963f, 8.184f, 4.224f,
|
||||||
2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f,
|
2.739f, 5.724f, 2.988f, 5.97f, 12.444f, 6.48f, 6.09f, 12.696f, 6.612f, 3.309f, 6.882f, 3.576f, 6.516f, 13.554f, 7.044f, 14.076f, 29.22f, 15.156f, 14.364f, 29.82f, 15.468f, 7.74f, 16.038f, 8.304f, 3.849f, 7.974f, 4.128f, 8.25f, 17.064f, 8.82f, 8.418f, 17.412f, 9.f, 4.503f, 9.3f, 4.8f,
|
||||||
3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f,
|
3.063f, 6.408f, 3.348f, 6.69f, 13.956f, 7.272f, 6.81f, 14.208f, 7.404f, 3.705f, 7.71f, 4.008f, 7.308f, 15.21f, 7.908f, 15.804f, 32.82f, 17.028f, 16.092f, 33.42f, 17.34f, 8.676f, 17.982f, 9.312f, 4.317f, 8.946f, 4.632f, 9.258f, 19.152f, 9.9f, 9.426f, 19.5f, 10.08f, 5.043f, 10.416f, 5.376f,
|
||||||
5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f,
|
5.619f, 11.484f, 5.868f, 11.73f, 23.964f, 12.24f, 12.138f, 24.792f, 12.66f, 6.333f, 12.93f, 6.6f, 12.42f, 25.362f, 12.948f, 25.884f, 52.836f, 26.964f, 26.748f, 54.588f, 27.852f, 13.932f, 28.422f, 14.496f, 6.873f, 14.022f, 7.152f, 14.298f, 29.16f, 14.868f, 14.754f, 30.084f, 15.336f, 7.671f, 15.636f, 7.968f,
|
||||||
6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f,
|
6.807f, 13.896f, 7.092f, 14.178f, 28.932f, 14.76f, 14.586f, 29.76f, 15.18f, 7.593f, 15.486f, 7.896f, 14.94f, 30.474f, 15.54f, 31.068f, 63.348f, 32.292f, 31.932f, 65.1f, 33.18f, 16.596f, 33.822f, 17.232f, 8.205f, 16.722f, 8.52f, 17.034f, 34.704f, 17.676f, 17.49f, 35.628f, 18.144f, 9.075f, 18.48f, 9.408f,
|
||||||
7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f,
|
7.995f, 16.308f, 8.316f, 16.626f, 33.9f, 17.28f, 17.034f, 34.728f, 17.7f, 8.853f, 18.042f, 9.192f, 17.46f, 35.586f, 18.132f, 36.252f, 73.86f, 37.62f, 37.116f, 75.612f, 38.508f, 19.26f, 39.222f, 19.968f, 9.537f, 19.422f, 9.888f, 19.77f, 40.248f, 20.484f, 20.226f, 41.172f, 20.952f, 10.479f, 21.324f, 10.848f,
|
||||||
9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f});
|
9.183f, 18.72f, 9.54f, 19.074f, 38.868f, 19.8f, 19.482f, 39.696f, 20.22f, 10.113f, 20.598f, 10.488f, 19.98f, 40.698f, 20.724f, 41.436f, 84.372f, 42.948f, 42.3f, 86.124f, 43.836f, 21.924f, 44.622f, 22.704f, 10.869f, 22.122f, 11.256f, 22.506f, 45.792f, 23.292f, 22.962f, 46.716f, 23.76f, 11.883f, 24.168f, 12.288f});
|
||||||
|
|
||||||
auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f,
|
auto expGradW = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW},{5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f,
|
||||||
5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f,
|
5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f, 5.28f,
|
||||||
7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f,
|
7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f,
|
||||||
7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f,
|
7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f, 7.84f,
|
||||||
10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f,
|
10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f,
|
||||||
10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f});
|
10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f, 10.4f});
|
||||||
|
|
||||||
auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{2.64f, 3.92f, 5.2f});
|
auto expGradB = NDArrayFactory::create<TypeParam>('c', {oC},{2.64f, 3.92f, 5.2f});
|
||||||
|
@ -1407,9 +1439,9 @@ TYPED_TEST(TypedConvolutionTests1, depthwise_conv2d_1) {
|
||||||
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC});
|
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, mC});
|
||||||
|
|
||||||
|
|
||||||
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,
|
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC},{12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,
|
||||||
13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f,
|
13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f,
|
||||||
12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,
|
12.f, 12.8f, 13.6f, 14.4f, 12.f, 12.8f, 13.6f, 14.4f, 5.2f, 5.6f, 6.f, 6.4f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f,
|
||||||
13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f});
|
13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 5.4f, 6.f, 6.6f, 7.2f, 5.6f, 6.4f, 7.2f, 8.f, 5.6f, 6.4f, 7.2f, 8.f, 2.f, 2.4f, 2.8f, 3.2f});
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
@ -1439,7 +1471,7 @@ TEST_F(ConvolutionTests1, depthwise_conv2d_2) {
|
||||||
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
|
auto weights = NDArrayFactory::create<double>('c', {kH, kW, iC, mC});
|
||||||
|
|
||||||
|
|
||||||
auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f,
|
auto expOutput = NDArrayFactory::create<double>('c', {bS, oH, oW, oC},{13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f,
|
||||||
13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f});
|
13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f, 13.2f, 14.4f, 15.6f, 16.8f});
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
@ -1697,13 +1729,13 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test1) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||||||
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
|
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
|
||||||
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
|
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
|
||||||
64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
|
64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
|
||||||
96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,
|
96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,
|
||||||
48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,
|
48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,
|
||||||
64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
|
64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
|
||||||
64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
|
64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f,
|
||||||
96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,
|
96.f, 96.f, 96.f, 96.f, 96.f, 96.f, 48.f, 48.f, 48.f, 64.f, 64.f, 64.f, 64.f, 64.f, 64.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f,
|
||||||
48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f});
|
48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 48.f, 48.f, 48.f, 48.f, 48.f, 48.f, 24.f, 24.f, 24.f, 32.f, 32.f, 32.f, 32.f, 32.f, 32.f, 16.f, 16.f, 16.f});
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights = 1.;
|
weights = 1.;
|
||||||
|
@ -1729,13 +1761,13 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test2) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||||||
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
|
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
|
||||||
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
|
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 4, 3, 3}, {534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
|
||||||
380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
|
380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
|
||||||
686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f,
|
686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f,
|
||||||
170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f,
|
170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f,
|
||||||
534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
|
534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
|
||||||
380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
|
380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 534.4f, 540.8f, 547.2f, 534.4f, 540.8f, 547.2f, 248.f, 251.2f, 254.4f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f,
|
||||||
686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f,
|
686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 314.4f, 319.2f, 324.f, 380.8f, 387.2f, 393.6f, 380.8f, 387.2f, 393.6f, 171.2f, 174.4f, 177.6f, 152.f, 155.2f, 158.4f, 152.f, 155.2f, 158.4f, 66.4f, 68.f, 69.6f,
|
||||||
170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f});
|
170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 170.4f, 175.2f, 180.f, 170.4f, 175.2f, 180.f, 70.8f, 73.2f, 75.6f, 75.2f, 78.4f, 81.6f, 75.2f, 78.4f, 81.6f, 28.f, 29.6f, 31.2f});
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
@ -1760,9 +1792,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test3) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iD, iH, iW, iC});
|
||||||
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
|
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
|
||||||
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,
|
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 2, 2, 2, 3}, {686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,
|
||||||
686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,
|
686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,
|
||||||
686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,
|
686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f,
|
||||||
686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f});
|
686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f, 686.4f, 696.f, 705.6f});
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
@ -1844,8 +1876,8 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test6) {
|
||||||
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||||||
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
|
auto weights = NDArrayFactory::create<TypeParam>('c', {kD, kH, kW, iC, oC});
|
||||||
auto bias = NDArrayFactory::create<TypeParam>('c', {oC},{1.f, 2.f, 3.f});
|
auto bias = NDArrayFactory::create<TypeParam>('c', {oC},{1.f, 2.f, 3.f});
|
||||||
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f,
|
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f,
|
||||||
51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f,
|
51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f, 49.f,
|
||||||
50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f});
|
50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 50.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f, 51.f});
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights = 0.5;
|
weights = 0.5;
|
||||||
|
@ -1873,9 +1905,9 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test7) {
|
||||||
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||||||
auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
|
auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
|
||||||
auto bias = NDArrayFactory::create<TypeParam>('c', {oC},{1.f, 2.f, 3.f});
|
auto bias = NDArrayFactory::create<TypeParam>('c', {oC},{1.f, 2.f, 3.f});
|
||||||
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f,
|
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f,
|
||||||
698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f,
|
698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f,
|
||||||
236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f,
|
236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 236.2f, 698.f, 698.f, 698.f, 698.f,
|
||||||
698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f});
|
698.f, 698.f, 698.f, 698.f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f, 1159.8f});
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
@ -1903,8 +1935,8 @@ TYPED_TEST(TypedConvolutionTests1, conv3d_test8) {
|
||||||
|
|
||||||
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
auto input = NDArrayFactory::create<TypeParam>('c', {bS, iC, iD, iH, iW});
|
||||||
auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
|
auto weights = NDArrayFactory::create<TypeParam>('c', {oC, iC, kD, kH, kW});
|
||||||
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f,
|
auto expected = NDArrayFactory::create<TypeParam>('c', {2, 3, 2, 2, 2},{235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f,
|
||||||
1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f,
|
1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f, 235.2f,
|
||||||
696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f});
|
696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 696.f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f, 1156.8f});
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
@ -1997,9 +2029,9 @@ TYPED_TEST(TypedConvolutionTests1, pointwise_conv2d_test1) {
|
||||||
auto bias = NDArrayFactory::create<TypeParam>('c', {oC});
|
auto bias = NDArrayFactory::create<TypeParam>('c', {oC});
|
||||||
|
|
||||||
|
|
||||||
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f,
|
auto expOutput = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, oC},{ 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f,
|
||||||
7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f,
|
7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f,
|
||||||
6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f,
|
6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f,
|
||||||
5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f});
|
5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f, 5.4f, 6.2f, 7.0f});
|
||||||
input = 2.;
|
input = 2.;
|
||||||
weights.linspace(0.1, 0.1);
|
weights.linspace(0.1, 0.1);
|
||||||
|
@ -2110,20 +2142,20 @@ TEST_F(ConvolutionTests1, vol2col_test2) {
|
||||||
auto columns = NDArrayFactory::create<float>('c', {kD, iC, kH, oW, kW, bS, oD, oH});
|
auto columns = NDArrayFactory::create<float>('c', {kD, iC, kH, oW, kW, bS, oD, oH});
|
||||||
columns.permutei({5, 1, 0, 2, 4, 6, 7, 3});
|
columns.permutei({5, 1, 0, 2, 4, 6, 7, 3});
|
||||||
columns = -1.;
|
columns = -1.;
|
||||||
auto columnsExpected = NDArrayFactory::create<float>('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,
|
auto columnsExpected = NDArrayFactory::create<float>('c', {bS, iC, kD, kH, kW, oD, oH, oW}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f,
|
||||||
10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f,
|
10.f, 11.f, 12.f, 2.f, 0.f, 4.f, 0.f, 6.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 3.f, 4.f, 5.f, 6.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 4.f, 0.f, 6.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 5.f, 6.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 6.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 7.f, 8.f,
|
||||||
9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 8.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 9.f, 10.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 10.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 11.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 12.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f,
|
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 14.f, 0.f, 16.f, 0.f, 18.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 15.f, 16.f, 17.f, 18.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 16.f, 0.f, 18.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 17.f, 18.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 18.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 20.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 21.f, 22.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 22.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
|
||||||
0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f,
|
0.f, 0.f, 0.f, 23.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 24.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 26.f, 0.f, 28.f, 0.f, 30.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 27.f, 28.f, 29.f, 30.f, 0.f, 0.f, 33.f, 34.f, 35.f, 36.f,
|
||||||
0.f, 0.f, 28.f, 0.f, 30.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 29.f, 30.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 30.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 32.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 33.f,
34.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 34.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 35.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 36.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 38.f, 0.f, 40.f,
0.f, 42.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 39.f, 40.f, 41.f, 42.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 40.f, 0.f, 42.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 41.f, 42.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 42.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 43.f, 44.f, 45.f, 46.f, 47.f,
48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 44.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 45.f, 46.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 46.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 47.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 48.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f,
0.f, 0.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 50.f, 0.f, 52.f, 0.f, 54.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 51.f, 52.f, 53.f, 54.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 52.f, 0.f, 54.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 53.f, 54.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f, 0.f, 0.f,
0.f, 0.f, 54.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 56.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 57.f, 58.f, 59.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 58.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 59.f, 60.f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 60.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 62.f, 0.f, 64.f, 0.f, 66.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 63.f, 64.f, 65.f, 66.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 64.f, 0.f, 66.f,
0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 65.f, 66.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 66.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 68.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 69.f, 70.f, 71.f, 72.f, 0.f, 0.f,
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 70.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 71.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 72.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f});
graph::Context context(1);
@ -2164,11 +2196,11 @@ TEST_F(ConvolutionTests1, upsampling2d_test1) {
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
input.linspace(1);
auto expOutput = NDArrayFactory::create<float>('c', {bS, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f,
7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f,
13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f,
25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f,
31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f});
nd4j::ops::upsampling2d op;
@ -2192,11 +2224,11 @@ TEST_F(ConvolutionTests1, upsampling2d_test2) {
auto input = NDArrayFactory::create<float>('c', {bS, iC, iH, iW});
input.linspace(1);
auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iH*factorH, iW*factorW}, {1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 1.f, 1.f, 1.f, 2.f, 2.f, 2.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f,
5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 5.f, 5.f, 5.f, 6.f, 6.f, 6.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 7.f, 7.f, 7.f, 8.f, 8.f, 8.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 9.f, 9.f, 9.f, 10.f, 10.f, 10.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f, 11.f, 11.f, 11.f, 12.f, 12.f, 12.f,
13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 13.f, 13.f, 13.f, 14.f, 14.f, 14.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 15.f, 15.f, 15.f, 16.f, 16.f, 16.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 17.f, 17.f, 17.f, 18.f, 18.f, 18.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f, 19.f, 19.f, 19.f, 20.f, 20.f, 20.f,
21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 21.f, 21.f, 21.f, 22.f, 22.f, 22.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 23.f, 23.f, 23.f, 24.f, 24.f, 24.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 25.f, 25.f, 25.f, 26.f, 26.f, 26.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f, 27.f, 27.f, 27.f, 28.f, 28.f, 28.f,
29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 29.f, 29.f, 29.f, 30.f, 30.f, 30.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f, 31.f, 31.f, 31.f, 32.f, 32.f, 32.f,
33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 33.f, 33.f, 33.f, 34.f, 34.f, 34.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f, 35.f, 35.f, 35.f, 36.f, 36.f, 36.f});
nd4j::ops::upsampling2d op;
@ -2221,20 +2253,20 @@ TEST_F(ConvolutionTests1, upsampling3d_test1) {
auto input = NDArrayFactory::create<float>('c', {bS, iD, iH, iW, iC});
input.linspace(1);
auto expOutput = NDArrayFactory::create<float>('c', {bS, iD*factorD, iH*factorH, iW*factorW, iC}, {1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f,
7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 1.f, 2.f, 3.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f,
7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 7.f, 8.f, 9.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 10.f, 11.f, 12.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f,
13.f, 14.f, 15.f, 13.f, 14.f, 15.f, 16.f, 17.f, 18.f, 16.f, 17.f, 18.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 19.f, 20.f, 21.f, 19.f, 20.f, 21.f, 22.f, 23.f, 24.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f,
25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f,
25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 25.f, 26.f, 27.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 28.f, 29.f, 30.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f,
31.f, 32.f, 33.f, 31.f, 32.f, 33.f, 34.f, 35.f, 36.f, 34.f, 35.f, 36.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f,
43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f, 37.f, 38.f, 39.f, 37.f, 38.f, 39.f, 40.f, 41.f, 42.f, 40.f, 41.f, 42.f,
43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 43.f, 44.f, 45.f, 43.f, 44.f, 45.f, 46.f, 47.f, 48.f, 46.f, 47.f, 48.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f,
49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f,
49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 49.f, 50.f, 51.f, 49.f, 50.f, 51.f, 52.f, 53.f, 54.f, 52.f, 53.f, 54.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f, 55.f, 56.f, 57.f, 55.f, 56.f, 57.f, 58.f, 59.f, 60.f, 58.f, 59.f, 60.f,
61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f,
67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 61.f, 62.f, 63.f, 61.f, 62.f, 63.f, 64.f, 65.f, 66.f, 64.f, 65.f, 66.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f,
67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f, 67.f, 68.f, 69.f, 67.f, 68.f, 69.f, 70.f, 71.f, 72.f, 70.f, 71.f, 72.f});
nd4j::ops::upsampling3d op;
@ -2258,17 +2290,17 @@ TEST_F(ConvolutionTests1, upsampling3d_test2) {
auto input = NDArrayFactory::create<float>('c', {bS, iC, iD, iH, iW});
input.linspace(1);
auto expOutput = NDArrayFactory::create<float>('c', {bS, iC, iD*factorD, iH*factorH, iW*factorW}, { 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 1.f, 1.f, 2.f, 2.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 3.f, 3.f, 4.f, 4.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f,
5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 5.f, 5.f, 6.f, 6.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 7.f, 7.f, 8.f, 8.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 9.f, 9.f, 10.f, 10.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f, 11.f, 11.f, 12.f, 12.f,
13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 13.f, 13.f, 14.f, 14.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 15.f, 15.f, 16.f, 16.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f,
17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 17.f, 17.f, 18.f, 18.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 19.f, 19.f, 20.f, 20.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 21.f, 21.f, 22.f, 22.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f, 23.f, 23.f, 24.f, 24.f,
25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 25.f, 25.f, 26.f, 26.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 27.f, 27.f, 28.f, 28.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f,
29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 29.f, 29.f, 30.f, 30.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 31.f, 31.f, 32.f, 32.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 33.f, 33.f, 34.f, 34.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f, 35.f, 35.f, 36.f, 36.f,
37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 37.f, 37.f, 38.f, 38.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 39.f, 39.f, 40.f, 40.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f,
41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 41.f, 41.f, 42.f, 42.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 43.f, 43.f, 44.f, 44.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 45.f, 45.f, 46.f, 46.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f, 47.f, 47.f, 48.f, 48.f,
49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 49.f, 49.f, 50.f, 50.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 51.f, 51.f, 52.f, 52.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f,
53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 53.f, 53.f, 54.f, 54.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 55.f, 55.f, 56.f, 56.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 57.f, 57.f, 58.f, 58.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f, 59.f, 59.f, 60.f, 60.f,
61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 61.f, 61.f, 62.f, 62.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 63.f, 63.f, 64.f, 64.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f,
65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 65.f, 65.f, 66.f, 66.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 67.f, 67.f, 68.f, 68.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 69.f, 69.f, 70.f, 70.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f, 71.f, 71.f, 72.f, 72.f});
nd4j::ops::upsampling3d op;
@ -2412,13 +2444,13 @@ TEST_F(ConvolutionTests1, deconv2d_test1) {
auto input = NDArrayFactory::create<float>('c', {bS, iH, iW, iC});
auto weights = NDArrayFactory::create<float>('c', {kH, kW, oC, iC});
auto exp = NDArrayFactory::create<float>('c', {bS, oH, oW, oC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f,
2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f});
input = 0.5;
weights.linspace(0.1, 0.1);
@ -2445,13 +2477,13 @@ TEST_F(ConvolutionTests1, deconv2d_test2) {
auto input = NDArrayFactory::create<float>('c', {bS, oH, oW, oC});
auto weights = NDArrayFactory::create<float>('c', {kH, kW, iC, oC});
auto exp = NDArrayFactory::create<float>('c', {bS, iH, iW, iC}, {2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f });
input = 0.5;
weights.linspace(0.1, 0.1);
@ -2673,13 +2705,13 @@ TYPED_TEST(TypedConvolutionTests1, deconv2d_tf_test1) {
auto input = NDArrayFactory::create<TypeParam>('c', {bS, oH, oW, oC});
auto weights = NDArrayFactory::create<TypeParam>('c', {kH, kW, iC, oC});
auto outShape = NDArrayFactory::create<TypeParam>('c', {4}, {static_cast<TypeParam>(bS), static_cast<TypeParam>(iH), static_cast<TypeParam>(iW), static_cast<TypeParam>(iC)});
auto exp = NDArrayFactory::create<TypeParam>('c', {bS, iH, iW, iC}, { 2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f,
2.75f, 7.75f, 12.75f, 17.75f, 22.75f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 30.5f, 40.5f, 50.5f, 60.5f, 70.5f, 27.75f, 32.75f, 37.75f, 42.75f, 47.75f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
55.5f, 65.5f, 75.5f, 85.5f, 95.5f, 161.f, 181.f, 201.f, 221.f, 241.f, 161.f, 181.f, 201.f, 221.f, 241.f, 105.5f, 115.5f, 125.5f, 135.5f, 145.5f,
52.75f, 57.75f, 62.75f, 67.75f, 72.75f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 130.5f, 140.5f, 150.5f, 160.5f, 170.5f, 77.75f, 82.75f, 87.75f, 92.75f, 97.75f});
input = 0.5;
weights.linspace(0.1, 0.1);
@ -428,10 +428,11 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) {
TEST_F(DeclarableOpsTests13, adjustHue_1) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
NDArray factor = NDArrayFactory::create<float>(0.5);
NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_hue op;
auto results = op.execute({&input}, {0.5}, {2});
auto results = op.execute({&input, &factor}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -525,10 +526,11 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) {
TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
NDArray factor = NDArrayFactory::create<float>(0.5);
NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32);
nd4j::ops::adjust_saturation op;
auto results = op.execute({&input}, {0.5}, {2});
auto results = op.execute({&input, &factor}, {}, {2});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
@ -159,18 +159,19 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
auto x = NDArrayFactory::create<double>('c', {4,4,3});
auto e = NDArrayFactory::create<double>('c', {4,4,3}, {
NDArray factor = NDArrayFactory::create<double>(2.);
-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
auto e = NDArrayFactory::create<double>('c', {4,4,3}, {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5});
});
x.linspace(1.);
nd4j::ops::adjust_contrast op;
auto result = op.execute({&x}, {2.}, {}, {});
auto result = op.execute({&x, &factor}, {}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto out = result->at(0);
// out->printIndexedBuffer("Adjusted Constrast");
ASSERT_TRUE(e.equalsTo(out));
delete result;
}
@ -1774,6 +1774,28 @@ TEST_F(DeclarableOpsTests3, betainc_test10) {
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, betainc_test11) {
NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, nd4j::DataType::FLOAT32);
NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32);
NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32);
NDArray expected('c', {4}, {0.912156, 0.634443, 0.898314, 0.624544}, nd4j::DataType::FLOAT32);
nd4j::ops::betainc op;
auto results = op.execute({&a, &b, &x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto *output = results->at(0);
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, zeta_test1) {
@ -2092,8 +2114,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
x.linspace(10.);
auto expected= NDArrayFactory::create<double>('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07});
nd4j::ops::polygamma op;
auto results = op.execute({&n, &x}, {}, {});
//ASSERT_FALSE(true);
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
TEST_F(DeclarableOpsTests3, polygamma_test4) {
NDArray n('c', {3,4}, {/*0.7788*/0, 0,1,2,3,4,5,6,7,8,9,10}, nd4j::DataType::DOUBLE);
NDArray x('c', {3,4}, {0.7717,0.9281,0.9846,0.4838,0.6433,0.6041,0.6501,0.7612,0.7605,0.3948,0.9493,0.8600}, nd4j::DataType::DOUBLE);
NDArray expected('c', {3,4}, {/*std::numeric_limits<double>::quiet_NaN()*/-1.031918, -7.021327e-01, 1.682743e+00, -1.851378e+01,3.604167e+01, -3.008293e+02,
1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE);
nd4j::ops::polygamma op;
auto results = op.execute({&n, &x}, {}, {});
@ -2108,6 +2148,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
delete results;
}
TEST_F(DeclarableOpsTests3, digamma_1) {
NDArray x('c', {18}, {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, nd4j::DataType::DOUBLE);
NDArray expected('c', {18}, {std::numeric_limits<double>::infinity(), -99996.761229, 3.091129, 7.401432, 1.792911,11.196838,10.630354, 0.03649, 2.11331,
std::numeric_limits<double>::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE);
nd4j::ops::digamma op;
auto results = op.execute({&x}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, results->status());
auto output = results->at(0);
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
delete results;
}
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, svd_test1) {