Shyrma merge max ind (#443)

* - provide correct possible output types in mergeMaxIndex op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - cleaning up the unneeded backprop arg in reverse_bp op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - improve clipByNorm both ff and bp

Signed-off-by: Yurii <iuriish@yahoo.com>

* - implementation and testing clipByAvgNorm_bp op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - pass biases in any way in dnnl lstm op, they are zeros when user doesn't provide them to us

Signed-off-by: Yurii <iuriish@yahoo.com>

* - start working on mkldnn concat op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on mkldnn concat

Signed-off-by: Yurii <iuriish@yahoo.com>

* missing declaration fix

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* - polishing mkl ops

Signed-off-by: Yurii <iuriish@yahoo.com>

* - testing and fixing bugs in mkl concat op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - fix linkage error for windows cuda build

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further conflicts resolving with master

Signed-off-by: Yurii <iuriish@yahoo.com>

* - fix format tags in mkldnn matmul op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - provide additional type cast in clip.cu

Signed-off-by: Yurii <iuriish@yahoo.com>

* - finally bug in mkldnn tanh_bp was caught

Co-authored-by: raver119@gmail.com <raver119@gmail.com>
master
Yurii Shyrma 2020-05-12 07:47:09 +03:00 committed by GitHub
parent 872a511042
commit 76f3553679
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 2130 additions and 2397 deletions

View File

@ -981,12 +981,12 @@ namespace sd {
* these methods suited for FlatBuffers use
*/
template <typename T>
std::vector<T> getBufferAsVector();
std::vector<T> getBufferAsVector() const;
std::vector<Nd4jLong> getShapeAsVector() const;
std::vector<int> getShapeAsVectorInt() const;
std::vector<Nd4jLong> getShapeInfoAsVector();
std::vector<int64_t> getShapeInfoAsFlatVector();
std::vector<int64_t> getShapeAsFlatVector();
std::vector<Nd4jLong> getShapeInfoAsVector() const;
std::vector<int64_t> getShapeInfoAsFlatVector() const;
std::vector<int64_t> getShapeAsFlatVector() const;
/**
* set new order and shape in case of suitable array length (in-place operation)

View File

@ -982,16 +982,16 @@ std::string NDArray::asString(Nd4jLong limit) {
////////////////////////////////////////////////////////////////////////
template<typename T>
std::vector<T> NDArray::getBufferAsVector() {
std::vector<T> NDArray::getBufferAsVector() const {
std::vector<T> vector(lengthOf());
for (Nd4jLong e = 0; e < lengthOf(); e++)
vector[e] = this->e<T>(e);
return vector;
}
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES);
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector() const, LIBND4J_TYPES);
////////////////////////////////////////////////////////////////////////
std::vector<int64_t> NDArray::getShapeAsFlatVector() {
std::vector<int64_t> NDArray::getShapeAsFlatVector() const {
std::vector<int64_t> vector(this->rankOf());
for (int e = 0; e < this->rankOf(); e++)
vector[e] = static_cast<int64_t>(this->sizeAt(e));
@ -1019,7 +1019,7 @@ std::vector<int> NDArray::getShapeAsVectorInt() const {
}
////////////////////////////////////////////////////////////////////////
std::vector<int64_t> NDArray::getShapeInfoAsFlatVector() {
std::vector<int64_t> NDArray::getShapeInfoAsFlatVector() const {
int magicNumber = shape::shapeInfoLength(this->rankOf());
std::vector<int64_t> vector(magicNumber);
@ -1030,7 +1030,7 @@ std::vector<int64_t> NDArray::getShapeInfoAsFlatVector() {
}
////////////////////////////////////////////////////////////////////////
std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() {
std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() const {
int magicNumber = shape::shapeInfoLength(this->rankOf());
std::vector<Nd4jLong> vector(magicNumber);
for (int e = 0; e < magicNumber; e++)

View File

@ -15,7 +15,8 @@
******************************************************************************/
//
// @author raver119@gmail.com
// @author raver119@gmail.com
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <system/op_boilerplate.h>
@ -27,24 +28,58 @@
namespace sd {
namespace ops {
//////////////////////////////////////////////////////////////////////////
CONFIGURABLE_OP_IMPL(clipbyavgnorm, 1, 1, true, 1, 0) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const bool isInplace = block.isInplace();
auto ts = NDArrayFactory::create(T_ARG(0), block.launchContext());
auto clipNorm = NDArrayFactory::create(T_ARG(0), block.launchContext());
helpers::clipByAveraged(block.launchContext(), *input, *output, *block.getIArguments(), ts, isInplace);
helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace, true);
return Status::OK();
}
DECLARE_TYPES(clipbyavgnorm) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
DECLARE_TYPES(clipbyavgnorm) {
getOpDescriptor()
->setAllowedInputTypes(sd::DataType::ANY)
->setAllowedOutputTypes({ALL_FLOATS});
}
//////////////////////////////////////////////////////////////////////////
CUSTOM_OP_IMPL(clipbyavgnorm_bp, 2, 1, false, 1, 0) {
auto input = INPUT_VARIABLE(0);
auto gradO = INPUT_VARIABLE(1);
auto gradI = OUTPUT_VARIABLE(0);
const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext());
helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, true);
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
DECLARE_SHAPE_FN(clipbyavgnorm_bp) {
Nd4jLong *newShape = nullptr;
COPY_SHAPE(inputShape->at(1), newShape);
return SHAPELIST(CONSTANT(newShape));
}
DECLARE_TYPES(clipbyavgnorm_bp) {
getOpDescriptor()
->setAllowedInputTypes(0, DataType::ANY)
->setAllowedInputTypes(1, {ALL_FLOATS})
->setAllowedOutputTypes(0, {ALL_FLOATS});
}
}
}

View File

@ -31,10 +31,10 @@ namespace ops {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
const auto clipNorm = NDArrayFactory::create(input->dataType(), T_ARG(0), block.launchContext());
const auto clipNorm = NDArrayFactory::create(output->dataType(), T_ARG(0), block.launchContext());
const bool isInplace = block.isInplace();
helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace);
helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace, false);
return Status::OK();
}
@ -45,15 +45,15 @@ namespace ops {
auto gradO = INPUT_VARIABLE(1);
auto gradI = OUTPUT_VARIABLE(0);
const auto clipNorm = NDArrayFactory::create(T_ARG(0));
const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext());
helpers::clipByNormBP(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm);
helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, false);
return Status::OK();
}
DECLARE_SHAPE_FN(clipbynorm_bp) {
auto inShapeInfo = inputShape->at(0);
auto inShapeInfo = inputShape->at(1);
Nd4jLong *newShape = nullptr;
COPY_SHAPE(inShapeInfo, newShape);

View File

@ -23,8 +23,8 @@
#include<ops/declarable/helpers/transforms.h>
#include<array>
namespace sd {
namespace ops {
namespace sd {
namespace ops {
//////////////////////////////////////////////////////////////////////////
@ -85,6 +85,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {
// ******** input validation ******** //
REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT op: output array should have the same type as inputs arrays !");
REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
for(int i = 1; i < numOfNonEmptyArrs; ++i)

View File

@ -33,7 +33,7 @@ CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) {
auto output = OUTPUT_VARIABLE(0);
std::vector<const NDArray*> inArrs(block.width());
for(int i = 0; i < block.width(); ++i)
inArrs[i] = INPUT_VARIABLE(i);
@ -46,7 +46,8 @@ DECLARE_SYN(MergeMaxIndex, mergemaxindex);
DECLARE_TYPES(mergemaxindex) {
getOpDescriptor()
->setAllowedInputTypes({ALL_INTS, ALL_FLOATS});
->setAllowedInputTypes({ALL_INTS, ALL_FLOATS})
->setAllowedOutputTypes({ALL_INDICES});
}
}
DECLARE_SHAPE_FN(mergemaxindex) {

View File

@ -52,7 +52,7 @@ namespace ops {
else {
// check the consistency of input dimensions to reverse along
shape::checkDimensions(input->rankOf(), axis);
helpers::reverse(block.launchContext(), input, output, &axis, false);
helpers::reverse(block.launchContext(), input, output, &axis);
}
return Status::OK();
@ -85,7 +85,7 @@ namespace ops {
// check the consistency of input dimensions to reverse along
shape::checkDimensions(input->rankOf(), axis);
// we just reverse back original array
helpers::reverse(block.launchContext(), eps, output, &axis, false);
helpers::reverse(block.launchContext(), eps, output, &axis);
}
return Status::OK();

View File

@ -36,6 +36,7 @@ namespace sd {
#if NOT_EXCLUDED(OP_clipbyavgnorm)
DECLARE_CONFIGURABLE_OP(clipbyavgnorm, 1, 1, true, 1, 0);
DECLARE_CUSTOM_OP(clipbyavgnorm_bp, 2, 1, false, 1, 0);
#endif
#if NOT_EXCLUDED(OP_cumsum)

View File

@ -15,83 +15,134 @@
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
// @author Yurii Shyrma (iuriish@yahoo.com)
// @author sgazeos@gmail.com
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/transforms.h>
#include <helpers/Loops.h>
#include <execution/Threads.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
template<typename T>
static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace, const bool useAverage) {
const int rank = input.rankOf();
const auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
NDArray* z = nullptr;
const T normActual = norm2.e<T>(0);
const T normClip = clipNorm.e<T>(0);
if(isInplace) {
z = &input;
}
else {
output.assign(input);
z = &output;
}
if (isInplace) {
if(dimensions.empty()) {
if(norm2.lengthOf() == 1) {
const NDArray actualNorm = useAverage ? z->reduceAlongDimension(reduce::Norm2, {}) / z->lengthOf() : z->reduceAlongDimension(reduce::Norm2, {});
if(normActual > normClip)
input *= (normClip / normActual);
}
else {
auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
const T iNormActual = norm2.e<T>(i);
if (iNormActual > normClip)
*listOfInSubArrs.at(i) *= normClip / iNormActual;
}
};
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
}
if(actualNorm.e<float>(0) > clipNorm.e<float>(0))
*z *= clipNorm / actualNorm;
}
else {
if(norm2.lengthOf() == 1) {
auto listOfSubArrs = z->allTensorsAlongDimension(dimensions);
if(normActual > normClip)
output.assign(input * (normClip / normActual));
else
output.assign(input);
}
else {
auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions);
auto listOfOutSubArrs = output.allTensorsAlongDimension(dimensions);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
auto inputSubArr = listOfInSubArrs.at(i);
auto outputSubArr = listOfOutSubArrs.at(i);
outputSubArr->assign(inputSubArr);
const T iNormActual = norm2.e<T>(i);
if (iNormActual > clipNorm.e<T>(0))
*outputSubArr *= clipNorm / iNormActual;
}
};
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
}
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
const NDArray actualNorm = useAverage ? listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {}) / listOfSubArrs.at(i)->lengthOf() : listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {});
if(actualNorm.e<float>(0) > clipNorm.e<float>(0))
*listOfSubArrs.at(i) *= clipNorm / actualNorm;
}
};
samediff::Threads::parallel_tad(func, 0, listOfSubArrs.size());
}
}
//////////////////////////////////////////////////////////////////////////
void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
template<typename T>
static void clipByNormBp_(const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage) {
const int rank = input.rankOf();
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
auto sums = input.reduceAlongDimension(reduce::Sum, dimensions);
if(norm2.lengthOf() == 1) {
const T norm = useAverage ? norm2.e<T>(0) / input.lengthOf() : norm2.e<T>(0);
auto clipVal = clipNorm.e<T>(0);
if(norm > clipVal) {
const T sum = sums.e<T>(0); // reduce to scalar
const T factor1 = clipVal / norm;
const T factor2 = static_cast<T>(1.f) / (norm * norm); // 1 / (norm*norm*norm)
auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) {
return factor1 * y * (static_cast<T>(1.f) - factor2 * x * sum);
};
const_cast<NDArray&>(input).applyPairwiseLambda<T>(const_cast<NDArray&>(gradO), lambda, gradI);
}
else
gradI.assign(gradO);
}
else {
auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions});
auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions});
auto inputSubArrs = input.allTensorsAlongDimension({dimensions});
auto clipVal = clipNorm.e<T>(0);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
auto gradOSubArr = gradOSubArrs.at(i);
auto gradISubArr = gradISubArrs.at(i);
const T norm = useAverage ? norm2.e<T>(i) / gradISubArr->lengthOf() : norm2.e<T>(i);
if (norm > clipVal) {
auto inputSubArr = inputSubArrs.at(i);
const T sum = sums.e<T>(i); // reduce to scalar
const T factor1 = clipVal / norm;
const T factor2 = static_cast<T>(1.f) / (norm * norm); // 1 / (norm*norm*norm)
auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) {
return factor1 * y * (static_cast<T>(1.f) - factor2 * x * sum);
};
inputSubArr->applyPairwiseLambda<T>(*gradOSubArr, lambda, *gradISubArr);
}
else
gradISubArr->assign(gradOSubArr);
}
};
samediff::Threads::parallel_tad(func, 0, gradISubArrs.size());
}
}
BUILD_SINGLE_TEMPLATE(template void clipByNormBp_, (const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage), FLOAT_TYPES);
//////////////////////////////////////////////////////////////////////////
void clipByNormBp(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage) {
const NDArray& castedInput = gradI.dataType() == input.dataType() ? input : input.cast(gradI.dataType());
BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBp_, (castedInput, gradO, gradI, dimensions, clipNorm, useAverage), FLOAT_TYPES);
}
template <typename T>
@ -132,125 +183,6 @@ void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, co
BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);
//////////////////////////////////////////////////////////////////////////
template<typename T>
static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
const int rank = input.rankOf();
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
if(norm2.lengthOf() == 1) {
const T N = norm2.e<T>(0);
auto cn = clipNorm.e<T>(0);
if(N > cn) {
const T sumOfProd = (input * gradO).reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
const T factor1 = static_cast<T>(1.f) / N;
const T factor3 = factor1 / (N * N); // 1 / (N*N*N)
auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) {
return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd);
};
(const_cast<NDArray&>(input)).applyPairwiseLambda<T>(const_cast<NDArray&>(gradO), lambda, gradI);
}
else
gradI.assign(gradO);
}
else {
auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions});
auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions});
auto inputSubArrs = input.allTensorsAlongDimension({dimensions});
auto cn = clipNorm.e<T>(0);
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i++) {
T N = norm2.e<T>(i);
auto gradOSubArr = gradOSubArrs.at(i);
auto gradISubArr = gradISubArrs.at(i);
if (N > cn) {
auto inputSubArr = inputSubArrs.at(i);
const T sumOfProd = (*inputSubArr * *gradOSubArr).reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
const T factor1 = static_cast<T>(1.f) / N;
const T factor3 = factor1 / (N * N); // 1 / (N*N*N)
auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) {
return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd);
};
inputSubArr->applyPairwiseLambda<T>(*gradOSubArr, lambda, *gradISubArr);
} else
gradISubArr->assign(gradOSubArr);
}
};
samediff::Threads::parallel_tad(func, 0, gradISubArrs.size());
}
}
void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBP_, (input, gradO, gradI, dimensions, clipNorm), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void clipByNormBP_, (const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm), FLOAT_TYPES);
//////////////////////////////////////////////////////////////////////////
template<typename T>
static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
auto cn = clipNorm.e<T>(0);
if (dimensions.size() == 0) {
// all-reduce
T n2 = input.reduceNumber(reduce::Norm2).e<T>(0) / input.lengthOf();
if (n2 <= cn) {
if (!isInplace)
output.assign(input);
}
else {
const T factor = cn / n2;
auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
input.applyLambda<T>(lambda, output);
}
}
else {
// along dimension
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false);
if (!isInplace)
output.assign(input);
auto tads = output.allTensorsAlongDimension(dimensions);
// TODO: make this CUDA-compliant somehow
for (int e = 0; e < tads.size(); e++) {
T n2 = norm2.e<T>(e) / tads.at(e)->lengthOf();
const T factor = cn / n2;
if (n2 > cn) {
auto lambda = LAMBDA_T(_x, factor) {return _x * factor;};
tads.at(e)->applyLambda<T>(lambda, output);
}
}
}
}
void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);
/*
if (d1 > params[1])
return params[1];
else if (d1 < params[0])
return params[0];
else return d1;
*/
template <typename T>
static void clipByValue_(NDArray& input, double leftBound, double rightBound, NDArray& output) {

View File

@ -29,7 +29,7 @@ namespace helpers {
//////////////////////////////////////////////////////////////////////////
template<typename T>
template<typename X, typename Z>
static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
const Nd4jLong numArgs = inArrs.size();
@ -37,17 +37,18 @@ static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& o
auto func = PRAGMA_THREADS_FOR {
for (auto e = start; e < stop; e++) {
T max = -DataTypeUtils::max<T>();
Nd4jLong idx = 0;
X max = -DataTypeUtils::max<X>();
Z idx = static_cast<Z>(0);
for (Nd4jLong i = 0; i < numArgs; i++) {
T v = inArrs[i]->e<T>(e);
X v = inArrs[i]->t<X>(e);
if (v > max) {
max = v;
idx = i;
idx = static_cast<Z>(i);
}
}
output.p(e, idx);
// FIXME, use .r<Z>(e)
output.t<Z>(e) = static_cast<Z>(idx);
}
};
@ -55,14 +56,14 @@ static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& o
}
void mergeMaxIndex(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES);
BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES, INDEXING_TYPES);
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
static void mergeMax_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
const Nd4jLong numArgs = inArrs.size();
auto x = inArrs[0];
@ -89,15 +90,15 @@ void mergeMax(sd::LaunchContext * context, const std::vector<const NDArray*>& in
//////////////////////////////////////////////////////////////////////////
template<typename T>
static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<NDArray*>& outArrs) {
// outArrs.size() == inArrs.size() - 1
const Nd4jLong numArgs = outArrs.size();
// last array is gradient
const auto gradient = inArrs[numArgs]->bufferAsT<T>();
auto length = inArrs[numArgs]->lengthOf();
bool bSameOrderAndEws1 = (1 == inArrs[numArgs]->ews());
if (bSameOrderAndEws1) {
auto gradOrdering = inArrs[numArgs]->ordering();
@ -108,8 +109,8 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
bSameOrderAndEws1 &= (1 == outArrs[i]->ews());
}
}
if(bSameOrderAndEws1){
auto func = PRAGMA_THREADS_FOR{
for (auto e = start; e < stop; e++) {
@ -130,7 +131,7 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
samediff::Threads::parallel_for(func, 0, length);
return;
}
auto gradShape = inArrs[numArgs]->shapeInfo();
std::vector<bool> vbSameShaepeAndStrides(numArgs);
for (int i = 0; i < numArgs; ++i) {
@ -145,12 +146,12 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
shape::index2coordsCPU(start, e, gradShape, coords);
const auto gradOffset = shape::getOffset(gradShape, coords);
T max = -DataTypeUtils::max<T>();
Nd4jLong nMaxIndex = 0;
for (Nd4jLong i = 0; i < numArgs; i++) {
const auto xOffset = vbSameShaepeAndStrides[i] ? gradOffset : shape::getOffset(inArrs[i]->shapeInfo(), coords);
const T* v = inArrs[i]->bufferAsT<T>();
if (v[xOffset] > max) {
@ -160,7 +161,7 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
}
const auto zOffset = vbSameShaepeAndStrides[nMaxIndex] ? gradOffset : shape::getOffset(outArrs[nMaxIndex]->shapeInfo(), coords);
T* z = outArrs[nMaxIndex]->bufferAsT<T>();
z[zOffset] = gradient[gradOffset];
}

View File

@ -193,13 +193,10 @@ static void reverseSequence_(sd::LaunchContext * context, const NDArray* input,
}
//////////////////////////////////////////////////////////////////////////
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs, bool isBackProp) {
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs) {
// we need to reverse axis only if that's new op
std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
auto listOut = output->allTensorsAlongDimension(dimensions);
auto listIn = input->allTensorsAlongDimension(dimensions);
auto listOut = output->allTensorsAlongDimension(*intArgs);
auto listIn = input->allTensorsAlongDimension(*intArgs);
NDArray *subArrIn, *subArrOut;

View File

@ -0,0 +1,334 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
// @author sgazeos@gmail.com
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/transforms.h>
#include <helpers/ShapeUtils.h>
#include <helpers/PointersManager.h>
#include <helpers/ConstantTadHelper.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void clipByNormCuda(const void* vClipNorm, const void* vNorm, const Nd4jLong* normShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int* dimensions, const int dimsLen, const bool useAverage) {
const T clipNorm = *reinterpret_cast<const T*>(vClipNorm);
const T* norm = reinterpret_cast<const T*>(vNorm);
T* z = reinterpret_cast<T*>(vz);
__shared__ Nd4jLong zLen, tadLen, totalThreads;
if (threadIdx.x == 0) {
zLen = shape::length(zShapeInfo);
tadLen = zLen / shape::length(normShapeInfo);
totalThreads = gridDim.x * blockDim.x;
}
__syncthreads();
int zCoords[MAX_RANK], normCoords[MAX_RANK];
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, zCoords);
// deduce norm coords
for (int j = 0; j < dimsLen; ++j)
normCoords[j] = zCoords[dimensions[j]];
const T actualNorm = useAverage ? norm[shape::getOffset(normShapeInfo, normCoords)] / tadLen : norm[shape::getOffset(normShapeInfo, normCoords)];
if(actualNorm > clipNorm)
z[shape::getOffset(zShapeInfo, zCoords)] *= clipNorm / actualNorm;
}
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
__host__ static void clipByNormCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
const void* vClipNorm, const void* vNorm, const Nd4jLong* normShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
const int* dimensions, const int dimsLen, const bool useAverage) {
clipByNormCuda<T><<<blocksPerGrid, threadsPerBlock, 512, *stream>>>(vClipNorm, vNorm, normShapeInfo, vz, zShapeInfo, dimensions, dimsLen, useAverage);
}
//////////////////////////////////////////////////////////////////////////
void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, const std::vector<int>& dims, const NDArray& clipNorm, const bool isInplace, const bool useAverage) {
NDArray* z = nullptr;
if(isInplace) {
z = &input;
}
else {
output.assign(input);
z = &output;
}
if(dims.empty()) {
const NDArray actualNorm = useAverage ? z->reduceAlongDimension(reduce::Norm2, {}) / z->lengthOf() : z->reduceAlongDimension(reduce::Norm2, {});
if(actualNorm.e<float>(0) > clipNorm.e<float>(0))
*z *= clipNorm / actualNorm;
}
else {
const NDArray actualNorms = z->reduceAlongDimension(reduce::Norm2, dims);
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(z->rankOf(), dims);
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (z->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
PointersManager manager(context, "clipByNorm");
const int* dimensions = reinterpret_cast<const int*>(manager.replicatePointer(dimsToExclude.data(), dimsToExclude.size() * sizeof(int)));
NDArray::prepareSpecialUse({z}, {z, &actualNorms, &clipNorm});
BUILD_SINGLE_SELECTOR(z->dataType(), clipByNormCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), clipNorm.specialBuffer(), actualNorms.specialBuffer(), actualNorms.specialShapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dimensions, (int)dimsToExclude.size(), useAverage), FLOAT_TYPES);
NDArray::registerSpecialUse({z}, {z, &actualNorms, &clipNorm});
manager.synchronize();
}
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void clipByNormBpCuda(const void* vClipNorm,
const void* vx, const Nd4jLong* xShapeInfo, // input
const void* vy, const Nd4jLong* yShapeInfo, // gradO
const void* vNorm, const Nd4jLong* normShapeInfo,
const void* vSum, const Nd4jLong* sumShapeInfo,
void* vz, const Nd4jLong* zShapeInfo, // gradI
const int* dimensions, const int dimsLen, const bool useAverage) {
const T clipNorm = *reinterpret_cast<const T*>(vClipNorm);
const T* norm = reinterpret_cast<const T*>(vNorm);
const T* sum = reinterpret_cast<const T*>(vSum);
const T* x = reinterpret_cast<const T*>(vx);
const T* y = reinterpret_cast<const T*>(vy);
T* z = reinterpret_cast<T*>(vz);
__shared__ Nd4jLong zLen, tadLen, totalThreads;
__shared__ bool sameOffsets;
if (threadIdx.x == 0) {
zLen = shape::length(zShapeInfo);
tadLen = zLen / shape::length(normShapeInfo);
totalThreads = gridDim.x * blockDim.x;
sameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo);
}
__syncthreads();
int zCoords[MAX_RANK], normCoords[MAX_RANK];
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
shape::index2coords(i, zShapeInfo, zCoords);
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
const auto yOffset = sameOffsets ? zOffset : shape::getOffset(yShapeInfo, zCoords);
// deduce norm coords
for (int j = 0; j < dimsLen; ++j)
normCoords[j] = zCoords[dimensions[j]];
const T actualNorm = useAverage ? norm[shape::getOffset(normShapeInfo, normCoords)] / tadLen : norm[shape::getOffset(normShapeInfo, normCoords)];
if(actualNorm > clipNorm) {
const T sumVal = sum[shape::getOffset(sumShapeInfo, normCoords)];
const auto xOffset = sameOffsets ? zOffset : shape::getOffset(xShapeInfo, zCoords);
z[zOffset] = (clipNorm / actualNorm) * y[yOffset] * (static_cast<T>(1.f) - (x[xOffset] * sumVal) / (actualNorm * actualNorm));
}
else
z[zOffset] = y[yOffset];
}
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
void clipByNormBp_(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dims, const NDArray& clipNorm, const bool useAverage) {
const int rank = input.rankOf();
auto actualNorms = input.reduceAlongDimension(reduce::Norm2, dims);
if(actualNorms.lengthOf() == 1) {
const T norm = useAverage ? actualNorms.e<T>(0) / static_cast<T>(input.lengthOf()) : actualNorms.e<T>(0);
auto clipVal = clipNorm.e<T>(0);
if(norm > clipVal) {
const T sum = input.reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
const T factor1 = clipVal / norm;
const T factor2 = static_cast<T>(1.f) / (norm * norm); // 1 / (norm*norm*norm)
auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) {
return factor1 * y * (static_cast<T>(1.f) - factor2 * x * sum);
};
const_cast<NDArray&>(input).applyPairwiseLambda(const_cast<NDArray&>(gradO), lambda, gradI);
}
else
gradI.assign(gradO);
}
else {
const NDArray actualNorms = input.reduceAlongDimension(reduce::Norm2, dims);
const NDArray sums = input.reduceAlongDimension(reduce::Sum, dims);
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(gradI.rankOf(), dims);
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
PointersManager manager(context, "clipByNormBp");
const int* dimensions = reinterpret_cast<const int*>(manager.replicatePointer(dimsToExclude.data(), dimsToExclude.size() * sizeof(int)));
NDArray::prepareSpecialUse({&gradI}, {&actualNorms, &sums, &clipNorm, &input, &gradO});
clipByNormBpCuda<T><<<blocksPerGrid, threadsPerBlock, 512, *context->getCudaStream()>>>(clipNorm.specialBuffer(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), actualNorms.specialBuffer(), actualNorms.specialShapeInfo(), sums.specialBuffer(), sums.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), dimensions, (int)dimsToExclude.size(), useAverage);
NDArray::registerSpecialUse({&gradI}, {&actualNorms, &sums, &clipNorm, &input, &gradO});
manager.synchronize();
}
}
BUILD_SINGLE_TEMPLATE(template void clipByNormBp_, (sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage), FLOAT_TYPES);
//////////////////////////////////////////////////////////////////////////
void clipByNormBp(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage) {
const NDArray& castedInput = gradI.dataType() == input.dataType() ? input : input.cast(gradI.dataType());
BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBp_, (context, castedInput, gradO, gradI, dimensions, clipNorm, useAverage), FLOAT_TYPES);
}
template <typename T>
void clipByGlobalNorm_(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))
for (auto i = 0; i < inputs.size(); i++) {
auto input = inputs[i];
auto l2norm = input->reduceNumber(reduce::Norm2);
globalNorm += l2norm * l2norm;
}
globalNorm.applyTransform(transform::Sqrt, globalNorm); // = sd::math::nd4j_sqrt(globalNorm);
outputs[inputs.size()]->p(0, globalNorm);
globalNorm.syncToHost();
const T factor = static_cast<T>(clipNorm) / globalNorm.e<T>(0);
for (size_t e = 0; e < inputs.size(); e++) {
// all-reduce
auto input = inputs[e];
auto output = outputs[e];
if (globalNorm.e<double>(0) <= clipNorm) {
output->assign(input);
}
else {
auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
input->applyLambda(lambda, *output);
}
}
}
void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (context, inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);
template <typename T>
static void __global__ clipByValueKernel(void* input, const Nd4jLong* inputShape, void* output, const Nd4jLong* outputShape, double leftBound, double rightBound) {
__shared__ T* outputBuf;
__shared__ T* inputBuf;
__shared__ Nd4jLong length;
__shared__ bool linearBuffers;
if (threadIdx.x == 0) {
outputBuf = reinterpret_cast<T *>(output);
inputBuf = reinterpret_cast<T *>(input);
length = shape::length(inputShape);
linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1;
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto step = gridDim.x * blockDim.x;
for (Nd4jLong e = tid; e < length; e += step) {
if (linearBuffers) {
if (inputBuf[e] > rightBound) outputBuf[e] = (T) rightBound;
else if (inputBuf[e] < leftBound) outputBuf[e] = (T) leftBound;
else outputBuf[e] = inputBuf[e];
}
else {
auto inputOffset = shape::getIndexOffset(e, inputShape);
auto outputOffset = shape::getIndexOffset(e, outputShape);
if (inputBuf[inputOffset] > rightBound) outputBuf[outputOffset] = (T) rightBound;
else if (inputBuf[inputOffset] < leftBound) outputBuf[outputOffset] = (T) leftBound;
else outputBuf[outputOffset] = inputBuf[outputOffset];
}
}
}
template <typename T>
static void clipByValue_(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
auto stream = context->getCudaStream();
if (!input.isActualOnDeviceSide())
input.syncToDevice();
NDArray::prepareSpecialUse({&output}, {&input});
clipByValueKernel<T><<<256, 512, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftBound, rightBound);
NDArray::registerSpecialUse({&output}, {&input});
}
void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (context, input, leftBound, rightBound, output), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void clipByValue_, (sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES);
}
}
}

View File

@ -210,14 +210,10 @@ namespace helpers {
}
//////////////////////////////////////////////////////////////////////////
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs, bool isBackProp) {
// we need to reverse axis only if that's new op
std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions);
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions);
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions);
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs) {
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), *intArgs);
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), *intArgs);
NDArray::prepareSpecialUse({output}, {input});

View File

@ -300,269 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray
manager.synchronize();
}
//////////////////////////////////////////////////////////////////////////
// x - input, y - gradO, z - gradI
template<typename X, typename Z>
__global__ static void clipByNormBPWholeArrCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) {
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
if(tid >= shape::length(zShapeInfo))
return;
const auto x = reinterpret_cast<const X*>(vx);
const auto y = reinterpret_cast<const Z*>(vy);
auto z = reinterpret_cast<Z*>(vz);
auto reducBuff = reinterpret_cast<Z*>(vreducBuff);
uint* count = reinterpret_cast<uint*>(vreducBuff) + 16384;
__shared__ Z* shMem;
__shared__ Nd4jLong len;
__shared__ bool amIinLastBlock;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
shMem = reinterpret_cast<Z*>(shmem);
len = shape::length(zShapeInfo); // xLen = yLen = zLen
}
__syncthreads();
// fill shared memory with array elements
const auto xVal = x[shape::getIndexOffset(tid, xShapeInfo)];
const auto yVal = y[shape::getIndexOffset(tid, yShapeInfo)];
shMem[2*threadIdx.x] = static_cast<Z>(xVal * xVal); // for norm
shMem[2*threadIdx.x + 1] = static_cast<Z>(xVal * yVal); // for input * gradO
__syncthreads();
// accumulate sum per block
for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
if (threadIdx.x < activeThreads && tid + activeThreads < len) {
shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)];
shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
}
__syncthreads();
}
// store accumulated sums in reduction buffer (reducBuff)
if (threadIdx.x == 0) {
reducBuff[2*blockIdx.x] = shMem[0];
reducBuff[2*blockIdx.x + 1] = shMem[1];
__threadfence();
amIinLastBlock = gridDim.x == 1 || (atomicInc(count, gridDim.x) == gridDim.x - 1);
}
__syncthreads();
// shared memory of last block is used for final summation of values stored in reduction buffer
if (amIinLastBlock) {
for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) {
shMem[2*threadIdx.x] = (i == threadIdx.x ) ? reducBuff[2*i] : reducBuff[2*i] + shMem[2*threadIdx.x];
shMem[2*threadIdx.x + 1] = (i == threadIdx.x ) ? reducBuff[2*i + 1] : reducBuff[2*i + 1] + shMem[2*threadIdx.x + 1];
}
__syncthreads();
// accumulate sum
for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < gridDim.x) {
shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)];
shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
}
__syncthreads();
}
if (threadIdx.x == 0) {
reducBuff[0] = math::nd4j_sqrt<Z,Z>(shMem[0]);
reducBuff[1] = shMem[1];
count = 0;
}
}
}
//////////////////////////////////////////////////////////////////////////
// x - input, y - gradO, z - gradI
template<typename X, typename Z>
__global__ static void clipByNormBPCalcGradCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) {
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const Nd4jLong len = shape::length(zShapeInfo); // xLen = yLen = zLen
if(tid >= len)
return;
const auto x = reinterpret_cast<const X*>(vx);
const auto y = reinterpret_cast<const Z*>(vy);
auto z = reinterpret_cast<Z*>(vz);
__shared__ Z norm, sumOfProd;
if (threadIdx.x == 0) {
norm = reinterpret_cast<Z*>(vreducBuff)[0];
sumOfProd = reinterpret_cast<Z*>(vreducBuff)[1];
}
__syncthreads();
const auto yOffset = shape::getIndexOffset(tid, yShapeInfo);
const auto zOffset = shape::getIndexOffset(tid, zShapeInfo);
if(norm > clipNormVal) {
const auto xOffset = shape::getIndexOffset(tid, xShapeInfo);
const Z factor1 = static_cast<Z>(1) / norm; // 1 / norm
const Z factor2 = factor1 / (norm * norm); // 1 / (norm * norm * norm)
z[zOffset] = clipNormVal * (factor1 * y[yOffset] - factor2 * sumOfProd * x[xOffset]);
}
else {
z[zOffset] = y[yOffset];
}
}
//////////////////////////////////////////////////////////////////////////
// x - input, y - gradO, z - gradI
template<typename X, typename Z>
__global__ static void clipByNormBPTadsCuda(const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const void* vy, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const Z clipNormVal) {
const auto x = reinterpret_cast<const X*>(vx);
const auto y = reinterpret_cast<const Z*>(vy);
auto z = reinterpret_cast<Z*>(vz);
__shared__ Z* shMem;
__shared__ Nd4jLong tadLen;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
shMem = reinterpret_cast<Z*>(shmem);
tadLen = shape::length(zTadShapeInfo); // xTadLen = yTadLen = zTadLen
}
__syncthreads();
const auto* xTad = x + xTadOffsets[blockIdx.x];
const auto* yTad = y + yTadOffsets[blockIdx.x];
auto* zTad = z + zTadOffsets[blockIdx.x];
// *** FIRST STAGE - ACCUMULATE REQUIRED SUMS *** //
Z norm = 0;
Z sumOfProd = 0;
for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) {
const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo);
const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo);
shMem[2*threadIdx.x] = static_cast<Z>(xTad[xOffset] * xTad[xOffset]); // for norm
shMem[2*threadIdx.x + 1] = static_cast<Z>(xTad[xOffset] * yTad[yOffset]); // for input * gradO
__syncthreads();
// accumulate sum per block
for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
if (threadIdx.x < activeThreads && i + activeThreads < tadLen) {
shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)];
shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
}
__syncthreads();
}
norm += shMem[0];
sumOfProd += shMem[1];
}
// *** SECOND STAGE - GRADIENT CALCULATION *** //
norm = math::nd4j_sqrt<Z,Z>(norm);
for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) {
const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo);
const auto zOffset = shape::getIndexOffset(i, zTadShapeInfo);
if(norm > clipNormVal) {
const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo);
const Z factor1 = static_cast<Z>(1) / norm; // 1 / norm
const Z factor2 = factor1 / (norm * norm); // 1 / (norm * norm * norm)
zTad[zOffset] = clipNormVal * (factor1 * yTad[yOffset] - factor2 * sumOfProd * xTad[xOffset]);
}
else {
zTad[zOffset] = yTad[yOffset];
}
}
}
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
static void clipByNormBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
const void* vy, const Nd4jLong* yShapeInfo, const Nd4jLong* yTadOffsets,
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
void* vreducBuff, const double clipNormVal) {
if(xTadOffsets == nullptr) { // means whole array
clipByNormBPWholeArrCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast<Z>(clipNormVal));
clipByNormBPCalcGradCuda<X,Z><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast<Z>(clipNormVal));
}
else // means tads using
clipByNormBPTadsCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, xTadOffsets, vy, yShapeInfo, yTadOffsets, vz, zShapeInfo, zTadOffsets, static_cast<Z>(clipNormVal));
}
BUILD_DOUBLE_TEMPLATE(template void clipByNormBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* xTadOffsets, const void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong* yTadOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), FLOAT_TYPES, FLOAT_TYPES);
//////////////////////////////////////////////////////////////////////////
void clipByNormBP(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
PointersManager manager(context, "clipByNormBP");
const double clipNormVal = clipNorm.e<double>(0);
const auto xType = input.dataType();
const auto zType = gradI.dataType();
const int threadsPerBlock = MAX_NUM_THREADS / 2;
const int sharedMem = threadsPerBlock * 2 * input.sizeOfT() + 128;
NDArray::prepareSpecialUse({&gradI}, {&input, &gradO});
if(dimensions.empty() || dimensions.size() == input.rankOf()) { // means whole array
const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), nullptr, gradO.specialBuffer(), gradO.specialShapeInfo(), nullptr, gradI.specialBuffer(), gradI.specialShapeInfo(), nullptr, context->getReductionPointer(), clipNormVal), FLOAT_TYPES, FLOAT_TYPES);
}
else { // means tads using
auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
auto packY = ConstantTadHelper::getInstance()->tadForDimensions(gradO.shapeInfo(), dimensions);
auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.shapeInfo(), dimensions);
const int blocksPerGrid = packX.numberOfTads();
BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradO.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), nullptr, clipNormVal), FLOAT_TYPES, FLOAT_TYPES);
}
NDArray::registerSpecialUse({&gradI}, {&input, &gradO});
manager.synchronize();
}
template <typename T>
static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) {
auto tid = blockIdx.x * blockDim.x;
@ -692,252 +429,6 @@ void clipByNormBP(sd::LaunchContext* context, const NDArray& input, const NDArra
output.setIdentity();
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void clipByNormInplaceKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, Nd4jLong const* inputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, T clipNorm) {
for (int arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) {
__shared__ T* z;
__shared__ Nd4jLong len;
if (threadIdx.x == 0) {
len = shape::length(shape);
z = inputBuffer + inputOffsets[arr];
}
__syncthreads();
for (int j = threadIdx.x; j < len; j+= blockDim.x) {
auto xIndex = shape::getIndexOffset(j, shape);
if(norm2Buf[arr] > clipNorm)
z[xIndex] *= clipNorm / norm2Buf[arr]; // case with ews = 1 and ordering is 'c'
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void clipByNormKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, Nd4jLong const* inputOffsets, T* outputBuffer, Nd4jLong const* outputShape, Nd4jLong const* outputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, T clipNorm) {
for (Nd4jLong arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) {
__shared__ T* x, *z;
__shared__ Nd4jLong lenZ;
__shared__ T norm2;
if (threadIdx.x == 0) {
x = inputBuffer + inputOffsets[arr];
z = outputBuffer + outputOffsets[arr];
lenZ = shape::length(outputShape);
norm2 = norm2Buf[shape::getIndexOffset(arr, norm2shape)];
}
__syncthreads();
for (Nd4jLong j = threadIdx.x; j < lenZ; j+= blockDim.x) {
auto xIndex = shape::getIndexOffset(j, shape);
auto zIndex = shape::getIndexOffset(j, outputShape);
if(norm2 > clipNorm) {
z[zIndex] = x[xIndex] * clipNorm / norm2; // case with ews = 1 and ordering is 'c'
} else {
z[zIndex] = x[xIndex];
}
//printf("%lld: %lf %lf\n", j, z[zIndex], x[xIndex]);
}
__syncthreads();
}
}
//////////////////////////////////////////////////////////////////////////
template<typename T>
static void clipByNorm_(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, NDArray const& clipNormA, const bool isInplace) {
const int rank = input.rankOf();
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
clipNormA.syncToHost();
//norm2.printBuffer("Norm2");
T const clipNorm = clipNormA.e<T>(0);
//clipNormA.printBuffer("ClipNorm");
auto stream = context->getCudaStream();
if (isInplace) {
if(norm2.lengthOf() == 1) {
norm2.syncToHost();
T norm2Val = norm2.e<T>(0);
if(norm2Val > clipNorm)
input *= clipNorm / norm2Val;
}
else {
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions);
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude);
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
//auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimsToExclude);
T* inputBuffer = reinterpret_cast<T*>(input.specialBuffer());
T* norm2buf = reinterpret_cast<T*>(norm2.specialBuffer());
clipByNormInplaceKernel<T><<<256, 512, 1024, *stream>>>(numOfSubArrs, inputBuffer, packX.specialShapeInfo(), packX.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm);
}
}
else {
if(norm2.lengthOf() == 1) {
norm2.syncToHost();
T norm2Val = norm2.e<T>(0);
if(norm2Val > clipNorm)
output.assign( input * (clipNorm / norm2Val));
else
output.assign( input );
}
else {
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions);
const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude);
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimensions);
T* inputBuffer = reinterpret_cast<T*>(input.specialBuffer());
T* norm2buf = reinterpret_cast<T*>(norm2.specialBuffer());
T* outputBuffer = reinterpret_cast<T*>(output.specialBuffer());
clipByNormKernel<T><<<256, 512, 1024, *stream>>>(numOfSubArrs, inputBuffer, packX.specialShapeInfo(), packX.specialOffsets(), outputBuffer, packZ.specialShapeInfo(), packZ.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm);
}
}
}
void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void clipByNorm_, (sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);
template <typename T>
void clipByGlobalNorm_(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))
for (auto i = 0; i < inputs.size(); i++) {
auto input = inputs[i];
auto l2norm = input->reduceNumber(reduce::Norm2);
globalNorm += l2norm * l2norm;
}
globalNorm.applyTransform(transform::Sqrt, globalNorm); // = sd::math::nd4j_sqrt(globalNorm);
outputs[inputs.size()]->p(0, globalNorm);
globalNorm.syncToHost();
const T factor = static_cast<T>(clipNorm) / globalNorm.e<T>(0);
for (size_t e = 0; e < inputs.size(); e++) {
// all-reduce
auto input = inputs[e];
auto output = outputs[e];
if (globalNorm.e<double>(0) <= clipNorm) {
output->assign(input);
}
else {
auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
input->applyLambda(lambda, *output);
}
}
}
void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (context, inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);
//////////////////////////////////////////////////////////////////////////
template<typename T>
static void clipByAveraged_(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
auto cn = clipNorm.e<T>(0);
if (dimensions.size() == 0) {
// all-reduce
T n2 = input.reduceNumber(reduce::Norm2).e<T>(0) / static_cast<T>(input.lengthOf());
if (n2 <= cn) {
if (!isInplace)
output.assign(input);
}
else {
const T factor = cn / n2;
//auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
//input.applyLambda<T>(lambda, output);
output.assign(input * factor);
}
}
else {
// along dimension
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false);
if (!isInplace)
output.assign(input);
auto tads = output.allTensorsAlongDimension(dimensions);
auto outTads = output.allTensorsAlongDimension(dimensions);
// TODO: make this CUDA-compliant somehow
for (int e = 0; e < tads.size(); e++) {
T n2 = norm2.e<T>(e) / static_cast<T>(tads.at(e)->lengthOf());
const T factor = cn / n2;
if (n2 > cn) {
//auto lambda = LAMBDA_T(_x, factor) {return _x * factor;};
tads.at(e)->applyScalar(scalar::Multiply, factor, *outTads.at(e));//applyLambda<T>(lambda, &output);
}
}
}
}
void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);
/*
if (d1 > params[1])
return params[1];
else if (d1 < params[0])
return params[0];
else return d1;
*/
template <typename T>
static void __global__ clipByValueKernel(void* input, Nd4jLong const* inputShape, void* output, Nd4jLong const* outputShape, double leftBound, double rightBound) {
__shared__ T* outputBuf;
__shared__ T* inputBuf;
__shared__ Nd4jLong length;
__shared__ bool linearBuffers;
if (threadIdx.x == 0) {
outputBuf = reinterpret_cast<T *>(output);
inputBuf = reinterpret_cast<T *>(input);
length = shape::length(inputShape);
linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1;
}
__syncthreads();
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
const auto step = gridDim.x * blockDim.x;
for (Nd4jLong e = tid; e < length; e += step) {
if (linearBuffers) {
if (inputBuf[e] > rightBound) outputBuf[e] = (T) rightBound;
else if (inputBuf[e] < leftBound) outputBuf[e] = (T) leftBound;
else outputBuf[e] = inputBuf[e];
}
else {
auto inputOffset = shape::getIndexOffset(e, inputShape);
auto outputOffset = shape::getIndexOffset(e, outputShape);
if (inputBuf[inputOffset] > rightBound) outputBuf[outputOffset] = (T) rightBound;
else if (inputBuf[inputOffset] < leftBound) outputBuf[outputOffset] = (T) leftBound;
else outputBuf[outputOffset] = inputBuf[outputOffset];
}
}
}
template <typename T>
static void clipByValue_(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
auto stream = context->getCudaStream();
if (!input.isActualOnDeviceSide())
input.syncToDevice();
NDArray::prepareSpecialUse({&output}, {&input});
clipByValueKernel<T><<<256, 512, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftBound, rightBound);
NDArray::registerSpecialUse({&output}, {&input});
}
void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (context, input, leftBound, rightBound, output), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void clipByValue_, (sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES);
}
}
}

View File

@ -29,9 +29,9 @@ namespace helpers {
void reverseSequence(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim);
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs, bool isBackProp);
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs);
}
}

View File

@ -63,13 +63,13 @@ namespace helpers {
void mergeAdd(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output);
void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs);
void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);
void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace, const bool useAverage);
void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace);
void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm);
void clipByNormBp(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage);
void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);
void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);
void clipByAveragedNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);
void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode);

View File

@ -1093,7 +1093,7 @@ namespace sd {
return ND4J_STATUS_OK;
NDArray *a0 = block.array(0);
for (int e = 0; e < block.width(); e++) {
for (int e = 1; e < block.width(); e++) {
auto aV = block.array(e);
if (!shape::equalsSoft(a0->shapeInfo(), aV->shapeInfo()))
return ND4J_STATUS_BAD_DIMENSIONS;

View File

@ -90,13 +90,12 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(*x, x_user_md);
mkldnnUtils::setBlockStrides(x, x_user_md);
// z, output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(z, z_user_md);
mkldnnUtils::setBlockStrides(*z, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -112,15 +111,10 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
// provide memory and check whether reorder is required
// x
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
if (zReorder)
dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_ff_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// mean
auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, const_cast<void*>(mean->buffer()));
@ -141,8 +135,8 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_ff_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
@ -151,7 +145,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
//////////////////////////////////////////////////////////////////////////
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights,
static void batchnormBpMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights,
NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) {
// unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x
@ -206,20 +200,17 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(dims, type, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(x, x_user_md);
mkldnnUtils::setBlockStrides(*x, x_user_md);
// dLdO
dnnl::memory::desc dLdO_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(&dLdO, dLdO_user_md);
mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md);
// dLdI
dnnl::memory::desc dLdI_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);
mkldnnUtils::setBlockStrides(dLdI, dLdI_user_md);
mkldnnUtils::setBlockStrides(*dLdI, dLdI_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -239,10 +230,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
// provide memory and check whether reorder is required
// x
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// dLdO
mkldnnUtils::loadDataToMklStream(&dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
// mean
auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, const_cast<void*>(mean->buffer()));
@ -253,10 +244,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
args[DNNL_ARG_VARIANCE] = var_mkl_mem;
// dLdI
auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->buffer());
const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc();
auto dLdI_mkl_mem = dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem;
auto dLdI_user_mem = mkldnnUtils::loadDataToMklStream(*dLdI, engine, stream, dLdI_user_md, op_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
// gamma and beta (and their gradients) if they are present
if(weights != nullptr) {
@ -272,8 +260,8 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (dLdIReorder)
dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);
if (op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], dLdI_user_mem);
stream.wait();
@ -662,9 +650,9 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {
const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);
if (shape::strideDescendingCAscendingF(dLdO->shapeInfo()))
batchnormBackPropMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW);
batchnormBpMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW);
else
batchnormBackPropMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW);
batchnormBpMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW);
*dLdM = 0;
*dLdV = 0;

View File

@ -0,0 +1,186 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>
#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <numeric>
namespace sd {
namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void concatMKLDNN(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
// data type
dnnl::memory::data_type type;
if(output.dataType() == DataType::FLOAT32)
type = dnnl::memory::data_type::f32;
else if(output.dataType() == DataType::HALF)
type = dnnl::memory::data_type::f16;
else if(output.dataType() == DataType::BFLOAT16)
type = dnnl::memory::data_type::bf16;
else if(output.dataType() == DataType::UINT8)
type = dnnl::memory::data_type::u8;
else
type = dnnl::memory::data_type::s8;
std::vector<dnnl::memory::desc> x_user_md(inArrs.size()), x_mkl_md(inArrs.size());
// inputs
for (int i = 0; i < inArrs.size(); ++i) {
dnnl::memory::dims dims = inArrs[i]->getShapeAsFlatVector();
x_user_md[i] = x_mkl_md[i] = dnnl::memory::desc(dims, type, mkldnnUtils::getFormat(*inArrs[i]));
mkldnnUtils::setBlockStrides(*inArrs[i], x_user_md[i]);
}
// output
dnnl::memory::dims dims = output.getShapeAsFlatVector();
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, mkldnnUtils::getFormat(output));
mkldnnUtils::setBlockStrides(output, z_user_md);
std::unordered_map<int, dnnl::memory> args;
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::concat::primitive_desc op_prim_desc(axis, x_mkl_md, engine);
dnnl::stream stream(engine);
// inputs
for (int i = 0; i < inArrs.size(); ++i)
mkldnnUtils::loadDataToMklStream(*inArrs[i], engine, stream, x_user_md[i], op_prim_desc.src_desc(i), args[DNNL_ARG_MULTIPLE_SRC + i]);
// outputs
auto z_user_mem = mkldnnUtils::loadDataToMklStream(output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// primitive execution
dnnl::concat(op_prim_desc).execute(stream, args);
// reorder output if necessary
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(concat, ENGINE_CPU) {
REQUIRE_TRUE(block.width() > 0, 0, "CONCAT MKLDNN op: No input arrays were provided");
const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);
const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();
// first of all take into account possible presence of empty arrays
// also if scalar is present -> copy its value to vector with length=1
std::vector<const NDArray*> nonEmptyArrs;
std::vector<int> arrsToDelete;
int index = 0;
bool allOfSameType = true;
auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0;
auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType();
for(int i = 0; i < numOfInArrs; ++i) {
auto input = INPUT_VARIABLE(i);
auto currentRank = input->rankOf();
if(!input->isEmpty()) {
allOfSameType &= (typeOfFirstArr == input->dataType());
if(input->rankOf() == 0) {
auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
vec->assign(input);
nonEmptyArrs.push_back(vec);
arrsToDelete.push_back(index);
}
else{
nonEmptyArrs.push_back(input);
}
++index;
}
}
const int numOfNonEmptyArrs = nonEmptyArrs.size();
if(numOfNonEmptyArrs == 0){
//All inputs are empty arrays -> return empty, mainly for TF import compatibility (no op)
REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT MKLDNN op: If all input variables are empty, output must be empty");
return Status::OK();
}
const int rank = nonEmptyArrs[0]->rankOf(); // look up to first non-empty array
int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
if(axis < 0){
axis += rank;
}
// ******** input validation ******** //
REQUIRE_TRUE(allOfSameType, 0, "CONCAT MKLDNN op: all of input arrays must have same type !");
REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT MKLDNN op: output array should have the same type as inputs arrays !");
REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT MKLDNN op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);
for(int i = 1; i < numOfNonEmptyArrs; ++i)
REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT MKLDNN op: all input arrays must have the same rank !");
for(int i = 1; i < numOfNonEmptyArrs; ++i) {
for(int dim = 0; dim < rank; ++dim)
if(dim != axis)
REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT MKLDNN op: all input arrays must have the same dimensions (except those on input axis) !");
}
// ******** end of input validation ******** //
auto output = OUTPUT_VARIABLE(0);
if(numOfNonEmptyArrs == 1)
output->assign(nonEmptyArrs[0]);
else
concatMKLDNN(nonEmptyArrs, *output, axis);
// delete dynamically allocated vectors with length=1
for(int index : arrsToDelete)
delete nonEmptyArrs[index];
return Status::OK();
}
//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(concat, ENGINE_CPU) {
auto z = OUTPUT_VARIABLE(0);
const auto zType = z->dataType();
return z->rankOf() < 7 && (zType==DataType::FLOAT32 || zType==DataType::HALF || zType==DataType::BFLOAT16 || zType==DataType::UINT8 || zType==DataType::INT8);
}
}
}
}

View File

@ -62,33 +62,23 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
auto type = dnnl::memory::data_type::f32;
std::vector<int> permut;
if(0 == wFormat)
permut = {3,2,0,1}; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
else if(2 == wFormat)
permut = {0,3,1,2}; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
// memory descriptors for arrays
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, x_user_md);
mkldnnUtils::setBlockStrides(*input, x_user_md);
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
w_user_md.data.format_kind = dnnl_blocked; // overrides format
uint i0, i1, i2, i3;
if(0 == wFormat) {
i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
}
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3;
}
else {
i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
}
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
}
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
// bias
dnnl::memory::desc b_mkl_md;
@ -98,7 +88,7 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
// output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(output, z_user_md);
mkldnnUtils::setBlockStrides(*output, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -114,10 +104,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias
if(bias != nullptr) {
@ -126,17 +116,14 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
}
// output
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// run calculations
dnnl::convolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
// shape::printArray(z_mkl_mem.map_data<float>(),8);
@ -170,64 +157,38 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
auto type = dnnl::memory::data_type::f32;
std::vector<int> permut;
if(0 == wFormat)
permut = {3,2,0,1}; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
else if(2 == wFormat)
permut = {0,3,1,2}; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
// memory descriptors for arrays
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, x_user_md);
mkldnnUtils::setBlockStrides(*input, x_user_md);
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
w_user_md.data.format_kind = dnnl_blocked; // overrides format
uint i0, i1, i2, i3;
if(0 == wFormat) {
i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
}
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3;
}
else {
i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
}
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
}
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
// gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) {
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
uint i0, i1, i2, i3;
if(0 == wFormat) {
i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
}
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3;
}
else {
i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
}
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
}
mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);
// gradB
dnnl::memory::desc gradB_mkl_md;
@ -256,10 +217,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
@ -274,16 +235,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
// gradW
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
// gradB
if(gradB != nullptr) {
@ -301,10 +256,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
stream.wait();

View File

@ -63,6 +63,12 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};
std::vector<int> permut;
if(0 == wFormat)
permut = {4,3,0,1,2}; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
else if(2 == wFormat)
permut = {0,4,1,2,3}; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
auto type = dnnl::memory::data_type::f32;
// memory descriptors for arrays
@ -70,29 +76,12 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, x_user_md);
mkldnnUtils::setBlockStrides(*input, x_user_md);
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
w_user_md.data.format_kind = dnnl_blocked; // overrides format
uint i0, i1, i2, i3, i4;
if(0 == wFormat) {
i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
}
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
}
else {
i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
}
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
}
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
// bias
dnnl::memory::desc b_mkl_md;
@ -102,7 +91,7 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
// output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(output, z_user_md);
mkldnnUtils::setBlockStrides(*output, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -118,10 +107,10 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias
if(bias != nullptr) {
@ -130,17 +119,14 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
}
// output
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// run calculations
dnnl::convolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
}
@ -177,68 +163,40 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
auto type = dnnl::memory::data_type::f32;
std::vector<int> permut;
if(0 == wFormat)
permut = {4,3,0,1,2}; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
else if(2 == wFormat)
permut = {0,4,1,2,3}; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
// memory descriptors for arrays
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, x_user_md);
mkldnnUtils::setBlockStrides(*input, x_user_md);
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
w_user_md.data.format_kind = dnnl_blocked; // overrides format
uint i0, i1, i2, i3, i4;
if(0 == wFormat) {
i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
}
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
}
else {
i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
}
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
}
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
// gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) {
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
uint i0, i1, i2, i3, i4;
if(0 == wFormat) {
i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
}
else if(1 == wFormat) {
i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
}
else {
i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
}
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4);
}
mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);
// gradB
dnnl::memory::desc gradB_mkl_md;
@ -267,10 +225,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
@ -285,16 +243,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
// gradW
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
// gradB
if(gradB != nullptr) {
@ -312,10 +264,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
stream.wait();

View File

@ -47,16 +47,13 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dH-1, dW-1 };
uint i0, i1, i2, i3;
if(0 == wFormat) {
i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
}
else if(1 == wFormat) {
i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
}
else {
i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
}
std::vector<int> permut;
if(0 == wFormat)
permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
else if(1 == wFormat)
permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
else
permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
// input type
dnnl::memory::data_type xType;
@ -99,16 +96,12 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
mkldnnUtils::setBlockStrides(input, x_user_md);
mkldnnUtils::setBlockStrides(*input, x_user_md);
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
// bias
dnnl::memory::desc b_mkl_md;
@ -118,7 +111,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
mkldnnUtils::setBlockStrides(output, z_user_md);
mkldnnUtils::setBlockStrides(*output, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -135,10 +128,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias
if(bias != nullptr) {
@ -147,17 +140,14 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
}
// output
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// run calculations
dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
@ -180,16 +170,13 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dH-1, dW-1 };
uint i0, i1, i2, i3;
if(0 == wFormat) {
i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
}
else if(1 == wFormat) {
i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
}
else {
i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
}
std::vector<int> permut;
if(0 == wFormat)
permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
else if(1 == wFormat)
permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
else
permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
// input type
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
@ -216,35 +203,27 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
mkldnnUtils::setBlockStrides(input, x_user_md);
mkldnnUtils::setBlockStrides(*input, x_user_md);
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
// gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);
// gradB
dnnl::memory::desc gradB_mkl_md;
@ -273,10 +252,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
@ -291,16 +270,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
// gradW
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
// gradB
if(gradB != nullptr) {
@ -318,10 +291,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
stream.wait();

View File

@ -31,7 +31,7 @@ namespace ops {
namespace platforms {
//////////////////////////////////////////////////////////////////////////
static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
static void deconv2TFdBpMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const bool isNCHW, const int wFormat) {
@ -67,21 +67,17 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
mkldnnUtils::setBlockStrides(*weights, w_user_md, {3,2,0,1}); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -101,23 +97,20 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
// provide memory buffers and check whether reorder is required
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO
mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
// run backward data calculations
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
stream.wait();
@ -189,7 +182,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
// gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
// }
deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat);
deconv2TFdBpMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat);
// delete weights;

View File

@ -48,16 +48,13 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
uint i0, i1, i2, i3, i4;
if(0 == wFormat) {
i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
}
else if(1 == wFormat) {
i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
}
else {
i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
}
std::vector<int> permut;
if(0 == wFormat)
permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
else if(1 == wFormat)
permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
else
permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
// input type
dnnl::memory::data_type xType;
@ -100,17 +97,12 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
mkldnnUtils::setBlockStrides(input, x_user_md);
mkldnnUtils::setBlockStrides(*input, x_user_md);
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
// bias
dnnl::memory::desc b_mkl_md;
@ -120,7 +112,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
mkldnnUtils::setBlockStrides(output, z_user_md);
mkldnnUtils::setBlockStrides(*output, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -137,10 +129,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias
if(bias != nullptr) {
@ -149,17 +141,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
}
// output
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// run calculations
dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
@ -185,16 +174,13 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
uint i0, i1, i2, i3, i4;
if(0 == wFormat) {
i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
}
else if(1 == wFormat) {
i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
}
else {
i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
}
std::vector<int> permut;
if(0 == wFormat)
permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
else if(1 == wFormat)
permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
else
permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
// input type
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
@ -221,37 +207,27 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
mkldnnUtils::setBlockStrides(input, x_user_md);
mkldnnUtils::setBlockStrides(*input, x_user_md);
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
w_user_md.data.format_kind = dnnl_blocked; // overrides format
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
// gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4);
mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);
// gradB
dnnl::memory::desc gradB_mkl_md;
@ -281,10 +257,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
@ -299,16 +275,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
// gradW
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
// gradB
if(gradB != nullptr) {
@ -326,10 +296,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
stream.wait();

View File

@ -28,7 +28,7 @@
using namespace dnnl;
namespace sd {
namespace sd {
namespace ops {
namespace platforms {
@ -109,7 +109,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, x_user_md);
mkldnnUtils::setBlockStrides(*input, x_user_md);
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@ -129,7 +129,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl);
mkldnnUtils::setBlockStrides(output, z_user_md);
mkldnnUtils::setBlockStrides(*output, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -146,10 +146,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias
if(bias != nullptr) {
@ -158,24 +158,21 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
}
// output
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// run calculations
dnnl::convolution_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
// shape::printArray(z_mkl_mem.map_data<float>(),8);
}
//////////////////////////////////////////////////////////////////////////
static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
static void depthwiseConv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
const int paddingMode, const bool isNCHW, const int wFormat) {
@ -235,7 +232,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
mkldnnUtils::setBlockStrides(input, x_user_md);
mkldnnUtils::setBlockStrides(*input, x_user_md);
// weights
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
@ -250,12 +247,12 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl);
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
// gradW
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
@ -294,10 +291,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// gradO
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
@ -312,16 +309,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
// gradW
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
// gradB
if(gradB != nullptr) {
@ -339,10 +330,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (gradWReorder)
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
stream.wait();
@ -458,7 +449,7 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
if(bias)
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
depthwiseConv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
return Status::OK();
}

View File

@ -169,71 +169,43 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any);
// x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc);
x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc);
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
mkldnnUtils::setBlockStrides(*x, x_user_md);
// wx
wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any);
wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
wx_user_md.data.format_kind = dnnl_blocked; // overrides format
wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0];
wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1];
wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2];
wx_user_md.data.format_desc.blocking.strides[3] = Wx->stridesOf()[3];
wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4];
mkldnnUtils::setBlockStrides(*Wx, wx_user_md);
// wr
wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any);
wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
wr_user_md.data.format_kind = dnnl_blocked; // overrides format
wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0];
wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1];
wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2];
wr_user_md.data.format_desc.blocking.strides[3] = Wr->stridesOf()[3];
wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4];
mkldnnUtils::setBlockStrides(*Wr, wr_user_md);
// h
h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any);
// h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc);
h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc);
h_user_md.data.format_kind = dnnl_blocked; // overrides format
h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0];
h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1];
h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2];
mkldnnUtils::setBlockStrides(*h, h_user_md);
// b
if(b) {
b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any);
b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo);
b_user_md.data.format_kind = dnnl_blocked; // overrides format
b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0];
b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1];
b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2];
b_user_md.data.format_desc.blocking.strides[3] = b->stridesOf()[3];
mkldnnUtils::setBlockStrides(*b, b_user_md);
}
// hI
if(hI) {
hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
hI_user_md.data.format_kind = dnnl_blocked; // overrides format
hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0];
hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1];
hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2];
hI_user_md.data.format_desc.blocking.strides[3] = hI->stridesOf()[3];
mkldnnUtils::setBlockStrides(*hI, hI_user_md);
}
// cI
if(cI) {
cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
cI_user_md.data.format_kind = dnnl_blocked; // overrides format
cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0];
cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1];
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2];
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[3];
mkldnnUtils::setBlockStrides(*cI, cI_user_md);
}
// hL
@ -241,20 +213,13 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any);
hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
hL_user_md.data.format_kind = dnnl_blocked; // overrides format
hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0];
hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1];
hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2];
hL_user_md.data.format_desc.blocking.strides[3] = hL->stridesOf()[3];
mkldnnUtils::setBlockStrides(*hL, hL_user_md);
}
if(cL) {
cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
cL_user_md.data.format_kind = dnnl_blocked; // overrides format
cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0];
cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1];
cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2];
cL_user_md.data.format_desc.blocking.strides[3] = cL->stridesOf()[3];
mkldnnUtils::setBlockStrides(*cL, cL_user_md);
}
// lstm memory description
@ -272,64 +237,49 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
// provide memory and check whether reorder is required
// x
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]);
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]);
// wx
mkldnnUtils::loadDataToMklStream(Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]);
mkldnnUtils::loadDataToMklStream(*Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]);
// wr
mkldnnUtils::loadDataToMklStream(Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]);
mkldnnUtils::loadDataToMklStream(*Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]);
// h
auto h_user_mem = dnnl::memory(h_user_md, engine, h->buffer());
const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc();
auto h_lstm_mem = hReorder ? dnnl::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
args[DNNL_ARG_DST_LAYER] = h_lstm_mem;
auto h_user_mem = mkldnnUtils::loadDataToMklStream(*h, engine, stream, h_user_md, lstm_prim_desc.dst_layer_desc(), args[DNNL_ARG_DST_LAYER]);
// b
if(b) {
mkldnnUtils::loadDataToMklStream(b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]);
}
if(b)
mkldnnUtils::loadDataToMklStream(*b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]);
// hI
if(hI) {
mkldnnUtils::loadDataToMklStream(hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]);
}
if(hI)
mkldnnUtils::loadDataToMklStream(*hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]);
// cI
if(cI) {
mkldnnUtils::loadDataToMklStream(cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]);
}
if(cI)
mkldnnUtils::loadDataToMklStream(*cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]);
bool hLReorder(false), cLReorder(false);
dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
// hL
if(hL) {
hL_user_mem = dnnl::memory(hL_user_md, engine, hL->buffer());
hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc();
hL_lstm_mem = hLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
args[DNNL_ARG_DST_ITER] = hL_lstm_mem;
}
if(hL)
hL_user_mem = mkldnnUtils::loadDataToMklStream(*hL, engine, stream, hL_user_md, lstm_prim_desc.dst_iter_desc(), args[DNNL_ARG_DST_ITER]);
// cL
if(cL) {
cL_user_mem = dnnl::memory(cL_user_md, engine, cL->buffer());
cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc();
cL_lstm_mem = cLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
args[DNNL_ARG_DST_ITER_C] = cL_lstm_mem;
}
if(cL)
cL_user_mem = mkldnnUtils::loadDataToMklStream(*cL, engine, stream, cL_user_md, lstm_prim_desc.dst_iter_c_desc(), args[DNNL_ARG_DST_ITER_C]);
// run calculations
lstm_forward(lstm_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (hReorder)
reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem);
if(hLReorder)
reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem);
if(cLReorder)
reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem);
if (lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc())
reorder(args[DNNL_ARG_DST_LAYER], h_user_mem).execute(stream, args[DNNL_ARG_DST_LAYER], h_user_mem);
if(lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc())
reorder(args[DNNL_ARG_DST_ITER], hL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER], hL_user_mem);
if(lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc())
reorder(args[DNNL_ARG_DST_ITER_C], cL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER_C], cL_user_mem);
stream.wait();
}
@ -377,9 +327,9 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step
// evaluate dimensions
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
const Nd4jLong sL = x->sizeAt(dataFormat);
const Nd4jLong bS = dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0);
const Nd4jLong nIn = x->sizeAt(2);
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
// inputs validations
@ -435,14 +385,21 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
WxR = new NDArray(Wx->reshape(Wx->ordering(), {1,dirDim,nIn,4,nOut}));
WrR = new NDArray(Wr->reshape(Wr->ordering(), {1,dirDim,nOut,4,nOut}));
if(b)
bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut}));
bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut}));
else
bR = new NDArray(x->ordering(), {1,dirDim,4,nOut}, x->dataType(), x->getContext()); // already nullified
if(hI)
hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut}));
if(cI)
cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut}));
if(hL)
hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false));
if(cL)
cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false));

View File

@ -31,20 +31,6 @@ namespace sd {
namespace ops {
namespace platforms {
dnnl::memory::format_tag get_format_tag(const sd::NDArray &array) {
switch (array.rankOf()) {
case 1:
return dnnl::memory::format_tag::ab;
case 2:
return array.ordering() == 'c' ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
case 3:
return array.ordering() == 'c' ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::cba;
default:
throw std::runtime_error("MKLDNN matmul only supports 2D/3D arrays");
}
}
//////////////////////////////////////////////////////////////////////////
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) {
@ -123,11 +109,16 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
else if(z->dataType() == DataType::INT8)
zType = dnnl::memory::data_type::s8;
const auto xFormat = xRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*xTR);
const auto yFormat = yRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*yTR);
const auto zFormat = zRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*zR);
// memory descriptors for arrays
dnnl::memory::desc x_mkl_md, x_user_md, y_mkl_md, y_user_md, z_mkl_md, z_user_md;
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR));
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR));
x_user_md = x_mkl_md = dnnl::memory::desc(xShape, xType, xFormat);
if(xTR->ews() != 1) {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = xRank == 1 ? 1 : xTR->strideAt(0);
@ -137,8 +128,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
}
// y
dnnl::memory::desc y_mkl_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR));
dnnl::memory::desc y_user_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR));
y_user_md = y_mkl_md = dnnl::memory::desc(yShape, yType, yFormat);
if(yTR->ews() != 1) {
y_user_md.data.format_kind = dnnl_blocked; // overrides format
y_user_md.data.format_desc.blocking.strides[0] = yRank == 1 ? 1 : yTR->strideAt(0);
@ -148,8 +138,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
}
// z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR));
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR));
z_user_md = z_mkl_md = dnnl::memory::desc(zShape, zType, zFormat);
if(zR->ews() != 1) {
z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = zRank == 1 ? 1 : zR->strideAt(0);
@ -181,37 +170,20 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
/*
auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->buffer());
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
if (xReorder)
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
args[DNNL_ARG_SRC] = x_mkl_mem;
*/
mkldnnUtils::loadDataToMklStream(*xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// y
mkldnnUtils::loadDataToMklStream(yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
/*
auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->buffer());
const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc();
auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem;
if (yReorder)
dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem);
args[DNNL_ARG_WEIGHTS] = y_mkl_mem;
*/
mkldnnUtils::loadDataToMklStream(*yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// z
auto z_user_mem = dnnl::memory(z_user_md, engine, zR->buffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*zR, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// run calculations
dnnl::matmul(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();

View File

@ -38,45 +38,65 @@ void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){
mklDims = dnnl::memory::dims(vDims);
}
//////////////////////////////////////////////////////////////////////
dnnl::memory::format_tag getFormat(const int rank){
if (2 == rank) {
return dnnl::memory::format_tag::ab;
}
else if (3 == rank) {
return dnnl::memory::format_tag::abc;
}
else if (4 == rank) {
return dnnl::memory::format_tag::abcd;
}
else if (5 == rank) {
return dnnl::memory::format_tag::abcde;
}
else if (6 == rank) {
return dnnl::memory::format_tag::abcdef;
}
return dnnl::memory::format_tag::a; // 1 == dataSetRank
dnnl::memory::format_tag getFormat(const NDArray& arr) {
dnnl::memory::format_tag result;
switch (arr.rankOf()) {
case 1:
result = dnnl::memory::format_tag::a;
break;
case 2:
result = arr.ordering() == 'c' ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
break;
case 3:
result = arr.ordering() == 'c' ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::cba;
break;
case 4:
result = dnnl::memory::format_tag::abcd;
break;
case 5:
result = dnnl::memory::format_tag::abcde;
break;
case 6:
result = dnnl::memory::format_tag::abcdef;
break;
default:
throw std::invalid_argument("MKLDNN getFormat: do we really want to use arras with rank > 6 ?");
}
return result;
}
//////////////////////////////////////////////////////////////////////
void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd){
void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector<int>& permut) {
if (array->ews() != 1 || array->ordering() != 'c') {
mklMd.data.format_kind = dnnl_blocked; // overrides format
for (auto i = 0; i < array->rankOf(); ++i) {
mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i);
if (array.ews() != 1 || (array.rankOf() > 3 && array.ordering() == 'f') || !permut.empty()) {
mklMd.data.format_kind = dnnl_blocked; // overrides format
if(permut.empty())
for (auto i = 0; i < array.rankOf(); ++i)
mklMd.data.format_desc.blocking.strides[i] = array.strideAt(i);
else {
if(array.rankOf() != permut.size())
throw std::invalid_argument("mkldnnUtils::setBlockStrides: size of permut vector is not equal to array rank !");
for (auto i = 0; i < array.rankOf(); ++i)
mklMd.data.format_desc.blocking.strides[i] = array.strideAt(permut[i]);
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////
void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
dnnl::memory& arg) {
dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream,
const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, dnnl::memory& arg) {
auto user_mem = dnnl::memory(user_md, engine,const_cast<void*>(array->buffer()));
auto user_mem = dnnl::memory(user_md, engine, const_cast<NDArray&>(array).buffer());
const bool bReorder = primitive_md != user_mem.get_desc();
auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem;
if (bReorder)
dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem);
arg = mkl_mem;
return user_mem;
}
//////////////////////////////////////////////////////////////////////
@ -122,33 +142,21 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
}
std::vector<int> permut;
if(!isNCHW)
permut = rank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
// memory descriptors for arrays
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
if(rank == 5)
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
}
mkldnnUtils::setBlockStrides(*input, x_user_md, permut);
// output
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
if(output->ews() != 1 || output->ordering() != 'c') {
z_user_md.data.format_kind = dnnl_blocked; // overrides format
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 :-1);
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1);
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2);
if(rank == 5)
z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(isNCHW ? 4 : 3);
}
mkldnnUtils::setBlockStrides(*output, z_user_md, permut);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -164,20 +172,17 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// output
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// run calculations
dnnl::pooling_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
}
@ -226,46 +231,27 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
}
std::vector<int> permut;
if(!isNCHW)
permut = rank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
// memory descriptors for arrays
// input
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(input->ews() != 1 || input->ordering() != 'c') {
x_user_md.data.format_kind = dnnl_blocked; // overrides format
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1);
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
if(rank == 5)
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
}
mkldnnUtils::setBlockStrides(*input, x_user_md, permut);
// gradO
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 :-1);
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1);
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2);
if(rank == 5)
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(isNCHW ? 4 : 3);
}
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md, permut);
// gradI
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 :-1);
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1);
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2);
if(rank == 5)
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(isNCHW ? 4 : 3);
}
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md, permut);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
dnnl::stream stream(engine);
@ -282,18 +268,15 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
std::unordered_map<int, dnnl::memory> args;
// gradO
mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
// gradI
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
if(mode == algorithm::pooling_max) {
// input
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// z
auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);
@ -310,10 +293,9 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
// run backward calculations
dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args);
// reorder gradI if necessary
if (gradIReorder)
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
if (op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
stream.wait();
}

View File

@ -100,6 +100,8 @@ namespace sd {
DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU);
DECLARE_PLATFORM(concat, ENGINE_CPU);
}
}
@ -123,19 +125,13 @@ namespace sd {
*/
void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims);
/**
* This function generate memory format tag based on rank
* @param const array rank
* This function evaluate memory format tag based on array shapeInfo
* @param const array
* @return memory format
*/
dnnl::memory::format_tag getFormat(const int rank);
/**
* This function generate memory format tag based on rank
* @param const pointer to dataset
* @param const dataset rank
* @param reference to memory descriptor
* @return memory format
*/
void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd);
dnnl::memory::format_tag getFormat(const NDArray& arr);
void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector<int>& permut = {});
//////////////////////////////////////////////////////////////////////
/**
* This function load and reorder user memory to mkl
@ -147,7 +143,7 @@ namespace sd {
* @param primitive memory descriptor
* @param dnnl arg activation enumerator
*/
void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
dnnl::memory& arg);
/**

View File

@ -35,32 +35,37 @@ namespace sd {
//////////////////////////////////////////////////////////////////////
static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) {
const auto xRank = x->rankOf();
dnnl::memory::dims xShape, zShape;
dnnl::memory::dims shape = x->getShapeAsFlatVector();
mkldnnUtils::getDims(x, xRank, xShape);
mkldnnUtils::getDims(z, xRank, zShape);
const int xRank = x->rankOf();
dnnl::memory::format_tag xFormat = mkldnnUtils::getFormat(*x);
dnnl::memory::format_tag zFormat = mkldnnUtils::getFormat(*z);
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
// optimized cases
if (2 == xRank && 0 == axis) {
format = dnnl::memory::format_tag::ba;
if(x->ews() == 1)
xFormat = dnnl::memory::format_tag::ba;
if(z->ews() == 1)
zFormat = dnnl::memory::format_tag::ba;
}
else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) {
format = dnnl::memory::format_tag::acdb;
if(x->ews() == 1)
xFormat = dnnl::memory::format_tag::acdb;
if(z->ews() == 1)
zFormat = dnnl::memory::format_tag::acdb;
}
dnnl::memory::data_type xType = dnnl::memory::data_type::f32;
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
mkldnnUtils::setBlockStrides(x, x_user_md);
dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md;
x_user_md = x_mkl_md = dnnl::memory::desc(shape, xType, xFormat);
mkldnnUtils::setBlockStrides(*x, x_user_md);
// z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, xType, format);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format);
mkldnnUtils::setBlockStrides(z, z_user_md);
z_user_md = z_mkl_md = dnnl::memory::desc(shape, xType, zFormat);
mkldnnUtils::setBlockStrides(*z, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -80,20 +85,17 @@ namespace sd {
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// run calculations
dnnl::softmax_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
}
@ -142,33 +144,19 @@ namespace sd {
//////////////////////////////////////////////////////////////////////
static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx, const int axis) {
const auto xRank = x->rankOf();
const auto dLdzRank = dLdz->rankOf();
dnnl::memory::dims xShape, dLdxShape, dLdzShape;
mkldnnUtils::getDims(x, xRank, xShape);
mkldnnUtils::getDims(dLdx, xRank, dLdxShape);
mkldnnUtils::getDims(dLdz, dLdzRank, dLdzShape);
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
dnnl::memory::desc x_user_md, x_mkl_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md;
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(x, x_user_md);
x_mkl_md = x_user_md = dnnl::memory::desc(x->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x));
mkldnnUtils::setBlockStrides(*x, x_user_md);
// dLdx
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
// todo if mkl does not support broadcast we can remove this
format = mkldnnUtils::getFormat(dLdzRank);
dLdx_mkl_md = dLdx_user_md = dnnl::memory::desc(dLdx->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx));
mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md);
// dLdz
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
dLdz_mkl_md = dLdz_user_md = dnnl::memory::desc(dLdz->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz));
mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -188,19 +176,18 @@ namespace sd {
// provide memory buffers and check whether reorder is required for forward
// input
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]);
// dLdz
mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]);
// dLdx
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer());
const bool dLdxReorder = op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc();
auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : dLdx_user_mem;
argsff[DNNL_ARG_DST] = dLdx_mkl_mem;
auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_DST]);
// check and arg set for backprob
argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
argsbp[DNNL_ARG_DST] = dLdx_mkl_mem;
// dLdz
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]);
argsbp[DNNL_ARG_DIFF_SRC] = argsff[DNNL_ARG_DST];
argsbp[DNNL_ARG_DST] = argsff[DNNL_ARG_DST];
// run calculations forward
dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff);
@ -209,8 +196,8 @@ namespace sd {
dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp);
// reorder outputs if necessary
if (dLdxReorder)
dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
if (op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc())
dnnl::reorder(argsff[DNNL_ARG_DST], dLdx_user_mem).execute(stream, argsff[DNNL_ARG_DST], dLdx_user_mem);
stream.wait();
}

View File

@ -34,22 +34,16 @@ namespace sd {
//////////////////////////////////////////////////////////////////////
static void tanhMKLDNN(const NDArray* x, NDArray* z) {
const auto xRank = x->rankOf();
dnnl::memory::dims xShape, zShape;
dnnl::memory::dims shape = x->getShapeAsFlatVector();
mkldnnUtils::getDims(x, xRank, xShape);
mkldnnUtils::getDims(z, xRank, zShape);
dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md;
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(x, x_user_md);
x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x));
mkldnnUtils::setBlockStrides(*x, x_user_md);
// z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(z, z_user_md);
z_user_md = z_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*z));
mkldnnUtils::setBlockStrides(*z, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -68,20 +62,17 @@ namespace sd {
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// run calculations
dnnl::eltwise_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
}
@ -121,28 +112,21 @@ namespace sd {
//////////////////////////////////////////////////////////////////////
static void tanhBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx) {
const auto xRank = x->rankOf();
dnnl::memory::dims xShape, dLdzShape, dLdxShape;
dnnl::memory::dims shape = x->getShapeAsFlatVector();
mkldnnUtils::getDims(x, xRank, xShape);
mkldnnUtils::getDims(dLdz, xRank, dLdzShape);
mkldnnUtils::getDims(dLdx, xRank, dLdxShape);
dnnl::memory::desc x_mkl_md, x_user_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md;
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(x, x_user_md);
// x
x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x));
mkldnnUtils::setBlockStrides(*x, x_user_md);
// dLdz
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
dLdz_user_md = dLdz_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz));
mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md);
// dLdx
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
dLdx_user_md = dLdx_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx));
mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -162,23 +146,20 @@ namespace sd {
// provide memory buffers and check whether reorder is required for forward
// input
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// dLdz
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
// dLdx
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer());
const bool dLdxReorder = op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc();
auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_prim_desc.diff_src_desc(), engine) : dLdx_user_mem;
args[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
// run calculations backward
dnnl::eltwise_backward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (dLdxReorder)
dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
if (op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdx_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], dLdx_user_mem);
stream.wait();
}

View File

@ -82,33 +82,23 @@ namespace sd {
// memory descriptors for arrays
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
mkldnnUtils::setBlockStrides(x, x_user_md);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, mkldnnUtils::getFormat(*x));
mkldnnUtils::setBlockStrides(*x, x_user_md);
// weights
dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any);
dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, format);
if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) {
dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, mkldnnUtils::getFormat(*weights));
mkldnnUtils::setBlockStrides(*weights, weights_user_md, bShouldTransp ? std::vector<int>({1,0}) : std::vector<int>());
weights_user_md.data.format_kind = dnnl_blocked; // overrides format
if (bShouldTransp) {
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1);
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0);
}
else {
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0);
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1);
}
}
// bias
dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x);
dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x);
mkldnnUtils::setBlockStrides(bias, bias_user_md);
dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a);
dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a);
mkldnnUtils::setBlockStrides(*bias, bias_user_md);
// z
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format);
mkldnnUtils::setBlockStrides(z, z_user_md);
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, mkldnnUtils::getFormat(*z));
mkldnnUtils::setBlockStrides(*z, z_user_md);
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -125,27 +115,24 @@ namespace sd {
// provide memory buffers and check whether reorder is required
// input
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
// weights
mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
// bias
auto bias_mkl_mem = dnnl::memory(bias_mkl_md, engine, const_cast<void*>(bias->buffer()));
args[DNNL_ARG_BIAS] = bias_mkl_mem;
// z
auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
args[DNNL_ARG_DST] = z_mkl_mem;
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// run calculations
dnnl::inner_product_forward(op_prim_desc).execute(stream, args);
// reorder outputs if necessary
if (zReorder)
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
stream.wait();
}
@ -160,7 +147,7 @@ namespace sd {
// [M,K] x [K,N] = [M,N]
const int M = x->sizeAt(0);
const int K = x->sizeAt(1); // K == wK
const int K = x->sizeAt(1); // K == wK
const int N = dLdz->sizeAt(1);
// input dims
dnnl::memory::dims xShape = dnnl::memory::dims({ M, K });
@ -168,71 +155,53 @@ namespace sd {
dnnl::memory::dims dLdzShape = dnnl::memory::dims({ M, N });
dnnl::memory::dims bShape = dnnl::memory::dims({ N });
// output dims
dnnl::memory::dims dLdxShape = xShape;
dnnl::memory::dims dLdwShape = wShape;
dnnl::memory::format_tag format = dnnl::memory::format_tag::ab;
dnnl::memory::data_type dataType = dnnl::memory::data_type::f32;
// memory descriptors for arrays
// x
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, format);
mkldnnUtils::setBlockStrides(x, x_user_md);
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*x));
mkldnnUtils::setBlockStrides(*x, x_user_md);
// weights
dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, format);
if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) {
dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*weights));
mkldnnUtils::setBlockStrides(*weights, weights_user_md, bShouldTransp ? std::vector<int>({1,0}) : std::vector<int>());
weights_user_md.data.format_kind = dnnl_blocked; // overrides format
if (bShouldTransp) {
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1);
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0);
}
else {
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0);
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1);
}
}
// bias
dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
mkldnnUtils::setBlockStrides(bias, bias_user_md);
dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*bias));
mkldnnUtils::setBlockStrides(*bias, bias_user_md);
// dLdz
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, format);
mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, mkldnnUtils::getFormat(*dLdz));
mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md);
// dLdw
dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, format);
dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, format);
if (dLdw->ews() != 1 || dLdw->ordering() != 'c' || bShouldTransp) {
dLdw_user_md.data.format_kind = dnnl_blocked; // overrides format
if (bShouldTransp) {
dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(1);
dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(0);
}
else {
dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(0);
dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(1);
}
}
dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*dLdw));
mkldnnUtils::setBlockStrides(*dLdw, dLdw_user_md, bShouldTransp ? std::vector<int>({1,0}) : std::vector<int>());
// dLdb
dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
mkldnnUtils::setBlockStrides(dLdb, dLdb_user_md);
dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*dLdb));
mkldnnUtils::setBlockStrides(*dLdb, dLdb_user_md);
// dLdx
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any);
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, format);
mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*dLdx));
mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md);
// create engine
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
// forward
// operation primitive description
dnnl::inner_product_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, dLdz_mkl_md);
@ -254,34 +223,25 @@ namespace sd {
dnnl::stream stream(engine);
// dLdz dw
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]);
mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]);
// dLdz - dx
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]);
mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]);
// input x for dw
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]);
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]);
// weights - dx
mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]);
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]);
// dLdw
auto dLdw_user_mem = dnnl::memory(dLdw_user_md, engine, dLdw->buffer());
const bool dLdwReorder = op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc();
auto dLdw_mkl_mem = dLdwReorder ? dnnl::memory(op_bpdw_prim_desc.diff_weights_desc(), engine) : dLdw_user_mem;
argsDw[DNNL_ARG_DIFF_WEIGHTS] = dLdw_mkl_mem;
// dLdw
auto dLdw_user_mem = mkldnnUtils::loadDataToMklStream(*dLdw, engine, stream, dLdw_user_md, op_bpdw_prim_desc.diff_weights_desc(), argsDw[DNNL_ARG_DIFF_WEIGHTS]);
// dLdx
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer());
const bool dLdxReorder = op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc();
auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_bpdx_prim_desc.diff_src_desc(), engine) : dLdx_user_mem;
argsDx[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
// dLdx
auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_bpdx_prim_desc.diff_src_desc(), argsDx[DNNL_ARG_DIFF_SRC]);
// dLdb
auto dLdb_user_mem = dnnl::memory(dLdb_user_md, engine, dLdb->buffer());
const bool dLdbReorder = op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc();
auto dLdb_mkl_mem = dLdbReorder ? dnnl::memory(op_bpdw_prim_desc.diff_bias_desc(), engine) : dLdb_user_mem;
argsDw[DNNL_ARG_DIFF_BIAS] = dLdb_mkl_mem;
auto dLdb_user_mem = mkldnnUtils::loadDataToMklStream(*dLdb, engine, stream, dLdb_user_md, op_bpdw_prim_desc.diff_bias_desc(), argsDw[DNNL_ARG_DIFF_BIAS]);
// run calculations dw
dnnl::inner_product_backward_weights(op_bpdw_prim_desc).execute(stream, argsDw);
@ -289,14 +249,14 @@ namespace sd {
dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx);
// reorder outputs if necessary
if (dLdxReorder)
dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
if (op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc())
dnnl::reorder(argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem).execute(stream, argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem);
if (dLdwReorder)
dnnl::reorder(dLdw_mkl_mem, dLdw_user_mem).execute(stream, dLdw_mkl_mem, dLdw_user_mem);
if (op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc())
dnnl::reorder(argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem).execute(stream, argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem);
if (dLdbReorder)
dnnl::reorder(dLdb_mkl_mem, dLdb_user_mem).execute(stream, dLdb_mkl_mem, dLdb_user_mem);
if (op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc())
dnnl::reorder(argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem).execute(stream, argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem);
stream.wait();
}
@ -315,7 +275,7 @@ namespace sd {
const int wRank = w->rankOf();
const int zRank = z->rankOf();
const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]
const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]
REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank);
REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank);
@ -378,7 +338,7 @@ namespace sd {
const int wRank = w->rankOf();
const int dLdzRank = dLdz->rankOf();
const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]
const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]
REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());

View File

@ -107,6 +107,25 @@ namespace sd {
// samediff::Threads::parallel_tad(func, 0, numOfArrs);
// }
// static Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {
// Nd4jLong result = 9223372036854775807LL;
// for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {
// const auto currentStride = shape::stride(inShapeInfo)[i];
// if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
// continue;
// if(result > currentStride)
// result = currentStride;
// }
// return result == 9223372036854775807LL ? 1 : result;
// }
template <typename T>
void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {
@ -150,7 +169,7 @@ void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inAr
// if(!areInputsContin || !allSameOrder)
// break;
// strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->shapeInfo());
// strideOfContigStride[i] = strideOverContigAxis(axis, inArrs[i]->getShapeInfo());
// }
// }
@ -158,7 +177,7 @@ void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inAr
// if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 shoud have stride = 1 for all inputs arrays and output array
// const auto zStep = shape::strideOverContigAxis(axis, output.shapeInfo());
// const auto zStep = strideOverContigAxis(axis, output.getShapeInfo());
// for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) {

View File

@ -184,7 +184,7 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
double tArgs[] = { -1.0, 1.0, 0.01 };
auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0);
shape::printShapeInfoLinear("Result", shapes->at(0));
// shape::printShapeInfoLinear("Result", shapes->at(0));
ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0)));
delete shapes;
@ -426,7 +426,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) {
0.928968489f, 0.684074104f
});
//get subarray
//get subarray
//get subarray
NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
NDArray expected = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) });
@ -627,7 +627,7 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
});
auto actual = NDArrayFactory::create<float>('c', { 3 });
//get subarray
//get subarray
NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) });
subArrHsvs.reshapei({ 3 });
NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
@ -635,7 +635,7 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
#if 0
//[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
subArrHsvs.printShapeInfo("subArrHsvs");
#endif
#endif
Context ctx(1);
ctx.setInputArray(0, &subArrHsvs);
@ -855,7 +855,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
-0.04447775f, -0.44518381f
});
//get subarray
//get subarray
NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
subArrRgbs.reshapei({ 3 });
@ -1054,7 +1054,7 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
0.280231822f, 1.91936605f
});
//get subarray
//get subarray
NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
subArrYiqs.reshapei({ 3 });
@ -1074,3 +1074,422 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
ASSERT_EQ(ND4J_STATUS_OK, status);
ASSERT_TRUE(expected.equalsTo(actual));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_1) {
auto x= NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0});
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {4.0}, {});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests16, clipbynorm_2) {
auto x= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {6.0}, {});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_3) {
auto x = NDArrayFactory::create<double>('c', {3, 5});
auto unities = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., 1.});
auto scale = NDArrayFactory::create<double>('c', {3, 1}, {1.1, 1., 0.9});
x.linspace(100.);
auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
x /= xNorm1;
xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true);
ASSERT_TRUE(unities.isSameShape(xNorm1));
ASSERT_TRUE(unities.equalsTo(xNorm1));
x *= scale;
xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {1.0}, {1});
auto z = result.at(0);
auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true);
auto exp = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., xNorm1.e<double>(2)});
ASSERT_TRUE(exp.isSameShape(&zNorm1));
ASSERT_TRUE(exp.equalsTo(&zNorm1));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_4) {
auto x = NDArrayFactory::create<double>('c', {3, 5}, {0.7044955, 0.55606544, 0.15833677, 0.001874401, 0.61595726, 0.3924779, 0.7414847, 0.4127324, 0.24026828, 0.26093036, 0.46741188, 0.01863421, 0.08528871, 0.529365, 0.5510694});
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {0.405392, 0.319980, 0.091113, 0.001079, 0.354444, 0.225846, 0.426676, 0.237501, 0.138259, 0.150149, 0.268965, 0.010723, 0.049078, 0.304615, 0.317105});
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {1.f}, {});
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_5) {
// auto x = NDArrayFactory::create<double>('c', {3, 5}, {1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5});
auto x = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 2.89271, 3.50524, 4.00892, 6., 7., 7.71389, 7.88678, 8.01784, 11., 12., 12.53507, 12.26833, 12.02676});
// auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1});
x.linspace(1);
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {15.f}, {0});
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_6) {
auto x = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 3., 4., 5., 4.95434, 5.78006, 6.60578, 7.43151, 8.25723, 5.64288, 6.15587, 6.66886, 7.18185, 7.69484});
x.linspace(1);
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {15.f}, {1});
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_7) {
auto x = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957});
x.linspace(1);
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {15.f}, {0,1});
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_8) {
auto x = NDArrayFactory::create<double>('c', {3, 5});
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957});
x.linspace(1);
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {15.}, {});
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_9) {
auto x = NDArrayFactory::create<double>('c', {2}, {3., 4.});
auto exp = NDArrayFactory::create<double>('c', {2}, {2.4, 3.2});
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {4.}, {});
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_10) {
auto x = NDArrayFactory::create<double>(6.);
auto exp = NDArrayFactory::create<double>(5.);
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {5.}, {});
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_11) {
auto x = NDArrayFactory::create<double>('c', {2, 3, 4});
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {1., 2., 3., 4., 4.44787, 5.33745, 6.22702, 7.1166 , 6.33046, 7.03384, 7.73723, 8.44061,
13., 14., 15., 16., 15.12277, 16.01235, 16.90192, 17.7915 ,14.77107, 15.47446, 16.17784, 16.88123});
x.linspace(1);
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {35.}, {0, 2});
auto output = result.at(0);
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_12) {
auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5,6, 7, 8, 9});
auto e = NDArrayFactory::create<double>('c', {3, 3}, {0.03198684, 0.06397368, 0.09596053, 0.12794736, 0.15993419, 0.19192106, 0.22390789, 0.25589472, 0.28788155});
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {0.54}, {});
ASSERT_EQ(e, *result.at(0));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_13) {
const int bS = 5;
const int nOut = 4;
const int axis = 0;
const double clip = 2.;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1]
auto colVect = NDArrayFactory::create<double>('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1});
auto expect = NDArrayFactory::create<double>('c', {bS, nOut});
auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut]
auto y = ( (x / norm2) * clip) * colVect ;
auto temp = (x / norm2) * clip;
for (int j = 0; j < nOut; ++j) {
auto yCol = y({0,0, j,j+1});
const double norm2Col = yCol.reduceNumber(reduce::Norm2).e<double>(0);
if (norm2Col <= clip)
expect({0,0, j,j+1}).assign(yCol);
else
expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) );
}
sd::ops::clipbynorm op;
auto result = op.evaluate({&y}, {clip}, {axis});
auto outFF = result.at(0);
ASSERT_TRUE(expect.isSameShape(outFF));
ASSERT_TRUE(expect.equalsTo(outFF));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_bp_1) {
const int bS = 2;
const int nOut = 3;
const double clip = 0.7;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
const OpArgsHolder argsHolderFF({&x}, {clip}, {});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {});
sd::ops::clipbynorm opFF;
sd::ops::clipbynorm_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_bp_2) {
const int bS = 2;
const int nOut = 3;
const int axis = 0;
const double clip = 0.7;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});
sd::ops::clipbynorm opFF;
sd::ops::clipbynorm_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbynorm_bp_3) {
const int bS = 2;
const int nOut = 3;
const int axis = 1;
const double clip = 1.;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});
sd::ops::clipbynorm opFF;
sd::ops::clipbynorm_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0});
sd::ops::clipbyavgnorm op;
auto result = op.evaluate({&x}, {0.8}, {});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_2) {
auto x= NDArrayFactory::create<float>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
auto exp= NDArrayFactory::create<float>('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f});
sd::ops::clipbyavgnorm op;
auto result = op.evaluate({&x}, {0.9}, {});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_1) {
const int bS = 2;
const int nOut = 3;
const double clip = 0.7;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
const OpArgsHolder argsHolderFF({&x}, {clip}, {});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {});
sd::ops::clipbyavgnorm opFF;
sd::ops::clipbyavgnorm_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_2) {
const int bS = 2;
const int nOut = 3;
const int axis = 1;
const double clip = 1.;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});
sd::ops::clipbyavgnorm opFF;
sd::ops::clipbyavgnorm_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_3) {
NDArray x('c', {2, 3, 4}, {-0.14 ,0.96 ,0.47 ,-0.98 ,0.03 ,0.95 ,0.33 ,-0.97 ,0.59 ,-0.92 ,-0.12 ,-0.33 ,0.82 ,-0.76 ,-0.69 ,-0.95 ,-0.77 ,0.25 ,-0.35 ,0.94 ,0.50 ,0.04 ,0.61 ,0.99}, sd::DataType::DOUBLE);
NDArray gradO('c', {2, 3, 4}, sd::DataType::DOUBLE);
const OpArgsHolder argsHolderFF({&x}, {0.7}, {0,2});
const OpArgsHolder argsHolderBP({&x, &gradO}, {0.7}, {0,2});
sd::ops::clipbyavgnorm opFF;
sd::ops::clipbyavgnorm_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}

View File

@ -50,7 +50,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
@ -68,7 +68,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_2) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests3, Test_Permute_1) {
@ -123,7 +123,7 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) {
ASSERT_TRUE(expI.isSameShape(i));
ASSERT_TRUE(expI.equalsTo(i));
}
TEST_F(DeclarableOpsTests3, Test_Unique_2) {
@ -171,7 +171,7 @@ TEST_F(DeclarableOpsTests3, Test_Rint_1) {
ASSERT_TRUE(exp.equalsTo(z));
}
@ -226,7 +226,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {
ASSERT_TRUE(exp0.isSameShape(z0));
ASSERT_TRUE(exp0.equalsTo(z0));
auto result1 = op.evaluate({&x, &axis}, {1}, {});
@ -244,94 +244,6 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {
}
TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_1) {
auto x = NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0});
sd::ops::clipbyavgnorm op;
auto result = op.evaluate({&x}, {0.8}, {});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) {
auto x= NDArrayFactory::create<float>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
auto exp= NDArrayFactory::create<float>('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f});
sd::ops::clipbyavgnorm op;
auto result = op.evaluate({&x}, {0.9}, {});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) {
auto x= NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0});
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {4.0}, {});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) {
auto x= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {6.0}, {});
auto z = result.at(0);
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) {
auto x = NDArrayFactory::create<double>('c', {3, 5});
auto unities = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., 1.});
auto scale = NDArrayFactory::create<double>('c', {3, 1}, {1.1, 1., 0.9});
x.linspace(100.);
auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
x /= xNorm1;
xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true);
ASSERT_TRUE(unities.isSameShape(xNorm1));
ASSERT_TRUE(unities.equalsTo(xNorm1));
x *= scale;
xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
sd::ops::clipbynorm op;
auto result = op.evaluate({&x}, {1.0}, {1});
auto z = result.at(0);
auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true);
auto exp = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., xNorm1.e<double>(2)});
ASSERT_TRUE(exp.isSameShape(&zNorm1));
ASSERT_TRUE(exp.equalsTo(&zNorm1));
}
TEST_F(DeclarableOpsTests3, Test_ListDiff_1) {
auto x= NDArrayFactory::create<float>('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
auto y= NDArrayFactory::create<float>('c', {3}, {1.f, 3.f, 5.f});
@ -551,7 +463,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) {
}
delete exp;
}
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) {
@ -579,7 +491,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) {
}
delete exp;
}
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) {
@ -607,7 +519,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) {
}
delete exp;
}
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) {
@ -635,7 +547,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) {
}
delete exp;
}
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) {
@ -663,7 +575,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) {
}
delete exp;
}
@ -692,7 +604,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) {
}
delete exp;
}
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {
@ -722,7 +634,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {
}
delete exp;
}
TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) {
@ -734,7 +646,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) {
sd::ops::batched_gemm op;
try {
auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});
ASSERT_TRUE(false);
} catch (std::invalid_argument &e) {
//
@ -875,7 +787,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test3) {
ASSERT_TRUE(expCt.isSameShape(ct));
ASSERT_TRUE(expCt.equalsTo(ct));
}
////////////////////////////////////////////////////////////////////
@ -946,7 +858,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test2) {
ASSERT_TRUE(expHt.isSameShape(ht));
ASSERT_TRUE(expHt.equalsTo(ht));
}
////////////////////////////////////////////////////////////////////
@ -1001,7 +913,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test1) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
}
///////////////////////////////////////////////////////////////////
@ -1021,7 +933,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test2) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
}
///////////////////////////////////////////////////////////////////
@ -1099,7 +1011,7 @@ TEST_F(DeclarableOpsTests3, diag_test_vector) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
delete input;
}
@ -1120,7 +1032,7 @@ TEST_F(DeclarableOpsTests3, diag_test_col_vector) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
delete input;
}
///////////////////////////////////////////////////////////////////
@ -1245,7 +1157,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
}
///////////////////////////////////////////////////////////////////
@ -1551,7 +1463,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output, 1e-6));
}
///////////////////////////////////////////////////////////////////
@ -1576,7 +1488,7 @@ TEST_F(DeclarableOpsTests3, betainc_test9) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
}
///////////////////////////////////////////////////////////////////
@ -1642,7 +1554,7 @@ TEST_F(DeclarableOpsTests3, betainc_test12) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
}
///////////////////////////////////////////////////////////////////
@ -1689,7 +1601,7 @@ TEST_F(DeclarableOpsTests3, zeta_test2) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
}
///////////////////////////////////////////////////////////////////
@ -1831,7 +1743,7 @@ TEST_F(DeclarableOpsTests3, zeta_test8) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
}
///////////////////////////////////////////////////////////////////
@ -1856,7 +1768,7 @@ TEST_F(DeclarableOpsTests3, zeta_test9) {
ASSERT_TRUE(expected.isSameShape(z));
ASSERT_TRUE(expected.equalsTo(z));
//
//
}
///////////////////////////////////////////////////////////////////
@ -1881,7 +1793,7 @@ TEST_F(DeclarableOpsTests3, zeta_test10) {
ASSERT_TRUE(expected.isSameShape(z));
ASSERT_TRUE(expected.equalsTo(z));
//
//
}
@ -1908,7 +1820,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) {
x.assign(0.5);
auto expected= NDArrayFactory::create<double>('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08});
sd::ops::polygamma op;
auto result = op.evaluate({&n, &x}, {}, {});
@ -1920,7 +1832,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) {
ASSERT_TRUE(expected.isSameShape(output));
ASSERT_TRUE(expected.equalsTo(output));
}
///////////////////////////////////////////////////////////////////
@ -2263,7 +2175,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) {
ASSERT_TRUE(expS.equalsTo(s));
ASSERT_TRUE(expS.isSameShape(s));
}
///////////////////////////////////////////////////////////////////
@ -2416,7 +2328,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) {
// ASSERT_NEAR(sd::math::nd4j_abs(expV.e<float>(i)), sd::math::nd4j_abs(v->e<float>(i)), 1e-5);
// }
//
//
// }
///////////////////////////////////////////////////////////////////

File diff suppressed because it is too large Load Diff

View File

@ -57,7 +57,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) {
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {
@ -78,7 +78,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {
ASSERT_EQ(exp, *z);
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
@ -100,7 +100,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
ASSERT_TRUE(z->isEmpty());
//ASSERT_EQ(exp, *z);
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
@ -122,7 +122,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
ASSERT_TRUE(z->equalsTo(exp));
//ASSERT_EQ(exp, *z);
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) {
@ -185,7 +185,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) {
auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
@ -205,7 +205,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
auto z = result.at(0);
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
@ -226,7 +226,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
auto z = result.at(0);
//ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
auto z = result.at(0);
//ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
@ -270,7 +270,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
auto z = result.at(0);
//ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
@ -292,7 +292,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
auto z = result.at(0);
//ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
@ -309,7 +309,7 @@ TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, Test_Order_1) {
@ -326,7 +326,7 @@ TEST_F(DeclarableOpsTests6, Test_Order_1) {
ASSERT_TRUE(exp.equalsTo(z));
ASSERT_NE(x.ordering(), z->ordering());
}
TEST_F(DeclarableOpsTests6, cumSum_1) {
@ -342,7 +342,7 @@ TEST_F(DeclarableOpsTests6, cumSum_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, cumSum_2) {
@ -359,7 +359,7 @@ TEST_F(DeclarableOpsTests6, cumSum_2) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, cumSum_3) {
@ -375,7 +375,7 @@ TEST_F(DeclarableOpsTests6, cumSum_3) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, cumSum_4) {
@ -391,7 +391,7 @@ TEST_F(DeclarableOpsTests6, cumSum_4) {
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, cumSum_5) {
@ -406,7 +406,7 @@ TEST_F(DeclarableOpsTests6, cumSum_5) {
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, cumSum_6) {
@ -421,7 +421,7 @@ TEST_F(DeclarableOpsTests6, cumSum_6) {
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, cumSum_7) {
@ -436,7 +436,7 @@ TEST_F(DeclarableOpsTests6, cumSum_7) {
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, cumSum_8) {
@ -452,7 +452,7 @@ TEST_F(DeclarableOpsTests6, cumSum_8) {
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -477,7 +477,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
ASSERT_EQ(Status::OK(), result.status());
auto z = result.at(0);
ASSERT_TRUE(expFF.equalsTo(z));
//************************************//
exclusive = 1; reverse = 0;
@ -486,7 +486,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
ASSERT_EQ(Status::OK(), result.status());
z = result.at(0);
ASSERT_TRUE(expTF.equalsTo(z));
//************************************//
exclusive = 0; reverse = 1;
@ -495,7 +495,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
ASSERT_EQ(Status::OK(), result.status());
z = result.at(0);
ASSERT_TRUE(expFT.equalsTo(z));
//************************************//
exclusive = 1; reverse = 1;
@ -504,7 +504,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
ASSERT_EQ(Status::OK(), result.status());
z = result.at(0);
ASSERT_TRUE(expTT.equalsTo(z));
}
@ -517,7 +517,7 @@ TEST_F(DeclarableOpsTests6, cumSum_10) {
auto result = op.evaluate({&x, &y}, {}, {1, 1});
ASSERT_EQ(Status::OK(), result.status());
}
////////////////////////////////////////////////////////////////////////////////
@ -536,7 +536,7 @@ TEST_F(DeclarableOpsTests6, cumSum_11) {
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -555,7 +555,7 @@ TEST_F(DeclarableOpsTests6, cumSum_12) {
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests6, cumSum_13) {
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -593,7 +593,7 @@ TEST_F(DeclarableOpsTests6, cumSum_14) {
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -612,7 +612,7 @@ TEST_F(DeclarableOpsTests6, cumSum_15) {
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -631,7 +631,7 @@ TEST_F(DeclarableOpsTests6, cumSum_16) {
ASSERT_TRUE(z->ews() == 1);
ASSERT_TRUE(x.ews() == 1);
}
////////////////////////////////////////////////////////////////////////////////
@ -664,7 +664,7 @@ TEST_F(DeclarableOpsTests6, cumSum_17) {
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -697,7 +697,7 @@ TEST_F(DeclarableOpsTests6, cumSum_18) {
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -731,7 +731,7 @@ TEST_F(DeclarableOpsTests6, cumSum_19) {
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -764,7 +764,7 @@ TEST_F(DeclarableOpsTests6, cumSum_20) {
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -779,30 +779,40 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) {
auto res = op.evaluate({&x, &y, &z}, {}, {}, {});
ASSERT_EQ(ND4J_STATUS_OK, res.status());
// res.at(0)->printIndexedBuffer("MergeMaxIndex Result is ");
// res.at(0)->printShapeInfo("Shape info for MergeMaxIdex");
// x.printIndexedBuffer("Input is");
ASSERT_TRUE(res.at(0)->equalsTo(exp));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) {
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f});
auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 60.f, 7.f, 8.f});
auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 6.f, 7.f, 80.f});
auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 0, 1, 2});
sd::ops::mergemaxindex op;
auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64});
ASSERT_EQ(ND4J_STATUS_OK, ress.status());
// res.at(0)->printIndexedBuffer("MergeMaxIndex2 Result is ");
// res.at(0)->printShapeInfo("Shape info for MergeMaxIdex2");
// x.printIndexedBuffer("Input is");
ASSERT_TRUE(ress.at(0)->equalsTo(exp));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_3) {
auto x1 = NDArrayFactory::create<double>('c', {3}, {1.f, 0.f, 0.f});
auto x2 = NDArrayFactory::create<double>('c', {3}, {0.f, 1.f, 0.f});
auto x3 = NDArrayFactory::create<double>('c', {3}, {0.f, 0.f, 1.f});
NDArray z('c', {3}, sd::DataType::INT32);
NDArray expZ('c', {3}, {0, 1, 2}, sd::DataType::INT32);
sd::ops::mergemaxindex op;
auto result = op.execute({&x1, &x2, &x3}, {&z}, {}, {}, {});
ASSERT_EQ(Status::OK(), result);
ASSERT_TRUE(z.equalsTo(expZ));
}
////////////////////////////////////////////////////////////////////////////////
@ -818,7 +828,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) {
//res.at(0)->printIndexedBuffer("Result is ");
//x.printIndexedBuffer("Input is");
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMod_1) {
@ -834,7 +844,7 @@ TEST_F(DeclarableOpsTests6, TestMod_1) {
// res.at(0)->printIndexedBuffer("MOD Result is ");
// x.printIndexedBuffer("Input is");
ASSERT_TRUE(res.at(0)->equalsTo(exp));
}
////////////////////////////////////////////////////////////////////////////////
@ -853,7 +863,7 @@ TEST_F(DeclarableOpsTests6, TestMod_BP_1) {
// x.printIndexedBuffer("Input is");
ASSERT_TRUE(res.at(0)->equalsTo(exp));
}
///////////////////////////////////////////////////////////////////////////////
@ -870,7 +880,7 @@ TEST_F(DeclarableOpsTests6, TestRank_1) {
ASSERT_EQ(ND4J_STATUS_OK, res.status());
ASSERT_TRUE(res.at(0)->equalsTo(exp));
}
TEST_F(DeclarableOpsTests6, TestDropout_2) {
// auto x0 = NDArrayFactory::create<double>('c', {10, 10});
@ -883,7 +893,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_2) {
ASSERT_EQ(ND4J_STATUS_OK, res.status());
}
TEST_F(DeclarableOpsTests6, TestDropout_3) {
@ -898,7 +908,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_3) {
ASSERT_EQ(ND4J_STATUS_OK, res.status());
}
////////////////////////////////////////////////////////////////////////////////
@ -922,7 +932,7 @@ TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) {
ASSERT_TRUE(expI.equalsTo(res.at(1)));
}
////////////////////////////////////////////////////////////////////////////////
@ -947,7 +957,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_1) {
ASSERT_TRUE(sumExp.equalsTo(res.at(1)));
ASSERT_TRUE(sqrExp.equalsTo(res.at(2)));
}
////////////////////////////////////////////////////////////////////////////////
@ -979,7 +989,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_2) {
ASSERT_TRUE(sumExp.equalsTo(res.at(1)));
ASSERT_TRUE(sqrExp.equalsTo(res.at(2)));
}
////////////////////////////////////////////////////////////////////////////////
@ -1270,7 +1280,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
ASSERT_TRUE(exp.equalsTo(z));
// ASSERT_TRUE(expNorm.equalsTo(norm));
}
////////////////////////////////////////////////////////////////////////////////
@ -1310,7 +1320,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) {
ASSERT_TRUE(exp.equalsTo(z));
ASSERT_TRUE(exp.equalsTo(y));
}
////////////////////////////////////////////////////////////////////////////////
@ -1344,7 +1354,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) {
ASSERT_TRUE(exp.equalsTo(z));
ASSERT_TRUE(exp.equalsTo(y));
}
////////////////////////////////////////////////////////////////////////////////
@ -1365,7 +1375,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1386,7 +1396,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1407,7 +1417,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1428,7 +1438,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1452,7 +1462,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1477,7 +1487,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1496,7 +1506,7 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1514,7 +1524,7 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1533,7 +1543,7 @@ TEST_F(DeclarableOpsTests6, LogDet_2) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1552,7 +1562,7 @@ TEST_F(DeclarableOpsTests6, LogDet_3) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1596,7 +1606,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1615,7 +1625,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1634,7 +1644,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1653,7 +1663,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1700,7 +1710,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
*/
TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
@ -1733,7 +1743,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1767,7 +1777,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1801,7 +1811,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1835,7 +1845,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_04) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
////////////////////////////////////////////////////////////////////////////////
@ -1864,7 +1874,7 @@ TEST_F(DeclarableOpsTests6, ReluLayer_1) {
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) {
@ -1917,7 +1927,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test1) {
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
}
///////////////////////////////////////////////////////////////////
@ -1960,7 +1970,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test2) {
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
}
///////////////////////////////////////////////////////////////////
@ -2003,7 +2013,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test3) {
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
}
///////////////////////////////////////////////////////////////////
@ -2045,7 +2055,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test4) {
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
}
///////////////////////////////////////////////////////////////////
@ -2087,7 +2097,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test5) {
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
}
///////////////////////////////////////////////////////////////////
@ -2141,7 +2151,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) {
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
}
///////////////////////////////////////////////////////////////////
@ -2194,7 +2204,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) {
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
}
@ -2247,7 +2257,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) {
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
}
///////////////////////////////////////////////////////////////////
@ -2290,7 +2300,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) {
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
}
@ -2335,7 +2345,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) {
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
}
///////////////////////////////////////////////////////////////////
@ -2377,7 +2387,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) {
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
}
///////////////////////////////////////////////////////////////////
@ -2418,7 +2428,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) {
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
}
///////////////////////////////////////////////////////////////////
@ -2459,7 +2469,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) {
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
}
///////////////////////////////////////////////////////////////////
@ -2521,7 +2531,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) {
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
}
///////////////////////////////////////////////////////////////////
@ -2581,7 +2591,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) {
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
}
///////////////////////////////////////////////////////////////////
@ -2637,7 +2647,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) {
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
}
///////////////////////////////////////////////////////////////////
@ -2696,7 +2706,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) {
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
}
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {
@ -2749,7 +2759,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
}
@ -2763,7 +2773,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_1) {
ASSERT_EQ(e, *result.at(0));
}
TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {
@ -2776,7 +2786,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {
ASSERT_EQ(e, *result.at(0));
}
TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
@ -2789,7 +2799,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
ASSERT_EQ(e, *result.at(0));
}

File diff suppressed because it is too large Load Diff

View File

@ -236,10 +236,10 @@ TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test1) {
auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
auto x1 = NDArrayFactory::create<double>('c', {2,2,4});
auto x2 = NDArrayFactory::create<double>('c', {2,1,4});
auto exp = NDArrayFactory::create<double>('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
auto x1 = NDArrayFactory::create<float>('c', {2,2,4});
auto x2 = NDArrayFactory::create<float>('c', {2,1,4});
auto exp = NDArrayFactory::create<float>('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.});
x0.linspace(1);
@ -261,10 +261,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test2) {
auto x0 = NDArrayFactory::create<double>('c', {1,3,1});
auto x1 = NDArrayFactory::create<double>('c', {1,2,1});
auto x2 = NDArrayFactory::create<double>('c', {1,1,1});
auto exp = NDArrayFactory::create<double>('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});
auto x0 = NDArrayFactory::create<float>('c', {1,3,1});
auto x1 = NDArrayFactory::create<float>('c', {1,2,1});
auto x2 = NDArrayFactory::create<float>('c', {1,1,1});
auto exp = NDArrayFactory::create<float>('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});
x0.linspace(1);
x1.linspace(1);
@ -285,10 +285,10 @@ TEST_F(DeclarableOpsTests9, concat_test2) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test3) {
auto x0 = NDArrayFactory::create<double>('c', {3});
auto x1 = NDArrayFactory::create<double>('c', {2});
auto x2 = NDArrayFactory::create<double>('c', {1});
auto exp = NDArrayFactory::create<double>('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});
auto x0 = NDArrayFactory::create<float>('c', {3});
auto x1 = NDArrayFactory::create<float>('c', {2});
auto x2 = NDArrayFactory::create<float>('c', {1});
auto exp = NDArrayFactory::create<float>('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});
x0.linspace(1);
x1.linspace(1);
@ -300,21 +300,17 @@ TEST_F(DeclarableOpsTests9, concat_test3) {
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto output = result.at(0);
output->printBuffer();
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test4) {
auto x0 = NDArrayFactory::create<double>('c', {1,1,1}, {1.f});
auto x1 = NDArrayFactory::create<double>('c', {1,1,1}, {2.f});
auto x2 = NDArrayFactory::create<double>('c', {1,1,1}, {3.f});
auto exp = NDArrayFactory::create<double>('c', {1,3,1}, {1.f, 2.f, 3.f});
auto x0 = NDArrayFactory::create<float>('c', {1,1,1}, {1.f});
auto x1 = NDArrayFactory::create<float>('c', {1,1,1}, {2.f});
auto x2 = NDArrayFactory::create<float>('c', {1,1,1}, {3.f});
auto exp = NDArrayFactory::create<float>('c', {1,3,1}, {1.f, 2.f, 3.f});
sd::ops::concat op;
@ -331,10 +327,10 @@ TEST_F(DeclarableOpsTests9, concat_test4) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test5) {
auto x0 = NDArrayFactory::create<double>(1.f);
auto x1 = NDArrayFactory::create<double>('c', {1}, {2.f});
auto x2 = NDArrayFactory::create<double>(3.f);
auto exp = NDArrayFactory::create<double>('c', {3}, {1.f, 2.f, 3.f});
auto x0 = NDArrayFactory::create<float>(1.f);
auto x1 = NDArrayFactory::create<float>('c', {1}, {2.f});
auto x2 = NDArrayFactory::create<float>(3.f);
auto exp = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});
sd::ops::concat op;
@ -351,10 +347,10 @@ TEST_F(DeclarableOpsTests9, concat_test5) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test6) {
auto x0 = NDArrayFactory::create<double>(1.f);
auto x1 = NDArrayFactory::create<double>('c', {2}, {2.f, 20.f});
auto x2 = NDArrayFactory::create<double>(3.f);
auto exp = NDArrayFactory::create<double>('c', {4}, {1.f, 2.f, 20.f, 3.f});
auto x0 = NDArrayFactory::create<float>(1.f);
auto x1 = NDArrayFactory::create<float>('c', {2}, {2.f, 20.f});
auto x2 = NDArrayFactory::create<float>(3.f);
auto exp = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 20.f, 3.f});
sd::ops::concat op;
@ -371,10 +367,10 @@ TEST_F(DeclarableOpsTests9, concat_test6) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test7) {
auto x0 = NDArrayFactory::create<double>(1.f);
auto x1 = NDArrayFactory::create<double>(2.f);
auto x2 = NDArrayFactory::create<double>(3.f);
auto exp = NDArrayFactory::create<double>('c', {3}, {1.f, 2.f, 3.f});
auto x0 = NDArrayFactory::create<float>(1.f);
auto x1 = NDArrayFactory::create<float>(2.f);
auto x2 = NDArrayFactory::create<float>(3.f);
auto exp = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});
sd::ops::concat op;
@ -391,8 +387,8 @@ TEST_F(DeclarableOpsTests9, concat_test7) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test8) {
auto x0 = NDArrayFactory::create<double>(1.f);
auto exp = NDArrayFactory::create<double>('c', {1}, {1.f});
auto x0 = NDArrayFactory::create<float>(1.f);
auto exp = NDArrayFactory::create<float>('c', {1}, {1.f});
sd::ops::concat op;
@ -409,8 +405,8 @@ TEST_F(DeclarableOpsTests9, concat_test8) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test9) {
auto x0 = NDArrayFactory::create<double>('c', {1}, {1.f});
auto exp = NDArrayFactory::create<double>('c', {1}, {1.f});
auto x0 = NDArrayFactory::create<float>('c', {1}, {1.f});
auto exp = NDArrayFactory::create<float>('c', {1}, {1.f});
sd::ops::concat op;
@ -427,10 +423,10 @@ TEST_F(DeclarableOpsTests9, concat_test9) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test10) {
auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
auto x2 = NDArrayFactory::create<double>('c', {2,1,4});
auto exp = NDArrayFactory::create<double>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
auto x2 = NDArrayFactory::create<float>('c', {2,1,4});
auto exp = NDArrayFactory::create<float>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f});
x0.linspace(1);
@ -452,10 +448,10 @@ TEST_F(DeclarableOpsTests9, concat_test10) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test11) {
auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
auto x2 = NDArrayFactory::create<double>('f', {2,1,4});
auto exp = NDArrayFactory::create<double>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
auto x2 = NDArrayFactory::create<float>('f', {2,1,4});
auto exp = NDArrayFactory::create<float>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f});
x0.linspace(1);
@ -477,10 +473,10 @@ TEST_F(DeclarableOpsTests9, concat_test11) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test12) {
auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
auto x2 = NDArrayFactory::create<double>('f', {2,1,4});
auto exp = NDArrayFactory::create<double>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
auto x2 = NDArrayFactory::create<float>('f', {2,1,4});
auto exp = NDArrayFactory::create<float>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f});
x0.linspace(1);
@ -502,10 +498,10 @@ TEST_F(DeclarableOpsTests9, concat_test12) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test13) {
auto x0 = NDArrayFactory::create<double>('f', {2,3,4});
auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
auto x2 = NDArrayFactory::create<double>('f', {2,1,4});
auto exp = NDArrayFactory::create<double>('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f,
auto x0 = NDArrayFactory::create<float>('f', {2,3,4});
auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
auto x2 = NDArrayFactory::create<float>('f', {2,1,4});
auto exp = NDArrayFactory::create<float>('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f,
3.f, 15.f, 7.f, 19.f,11.f, 23.f, 3.f, 11.f, 7.f, 15.f, 3.f, 7.f, 4.f, 16.f, 8.f, 20.f,12.f, 24.f, 4.f, 12.f, 8.f, 16.f, 4.f, 8.f});
x0.linspace(1);
@ -527,8 +523,8 @@ TEST_F(DeclarableOpsTests9, concat_test13) {
TEST_F(DeclarableOpsTests9, concat_test14) {
NDArray x0('c', {1, 40, 60}, sd::DataType::DOUBLE);
NDArray x1('c', {1, 40, 60}, sd::DataType::DOUBLE);
NDArray x0('c', {1, 40, 60}, sd::DataType::FLOAT32);
NDArray x1('c', {1, 40, 60}, sd::DataType::FLOAT32);
x0 = 1.;
x1 = 2.;
@ -544,7 +540,7 @@ TEST_F(DeclarableOpsTests9, concat_test14) {
for (int e = 0; e < numOfTads; ++e) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
auto mean = tad.meanNumber().e<float>(0);
ASSERT_NEAR((e+1)*1., mean, 1e-5);
}
@ -552,9 +548,9 @@ TEST_F(DeclarableOpsTests9, concat_test14) {
}
TEST_F(DeclarableOpsTests9, concat_test15) {
auto x = NDArrayFactory::create<double>('c', {2}, {1, 0});
auto y = NDArrayFactory::create<double> (3.0f);
auto exp = NDArrayFactory::create<double>('c', {3}, {1, 0, 3});
auto x = NDArrayFactory::create<float>('c', {2}, {1, 0});
auto y = NDArrayFactory::create<float> (3.0f);
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 0, 3});
sd::ops::concat op;
auto result = op.evaluate({&x, &y}, {}, {0});
@ -571,9 +567,9 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test16) {
auto x = NDArrayFactory::create<double>('c', {0,2,3});
auto y = NDArrayFactory::create<double>('c', {0,2,3});
auto exp = NDArrayFactory::create<double>('c', {0,2,3});
auto x = NDArrayFactory::create<float>('c', {0,2,3});
auto y = NDArrayFactory::create<float>('c', {0,2,3});
auto exp = NDArrayFactory::create<float>('c', {0,2,3});
sd::ops::concat op;
auto result = op.evaluate({&x, &y}, {}, {0});
@ -587,8 +583,8 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test17) {
NDArray x0('c', {1, 55, 40}, sd::DataType::DOUBLE);
NDArray x1('c', {1, 55, 40}, sd::DataType::DOUBLE);
NDArray x0('c', {1, 55, 40}, sd::DataType::FLOAT32);
NDArray x1('c', {1, 55, 40}, sd::DataType::FLOAT32);
x0 = 1.;
x1 = 2.;
@ -606,7 +602,7 @@ TEST_F(DeclarableOpsTests9, concat_test17) {
for (int e = 0; e < numOfTads; ++e) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
auto mean = tad.meanNumber().e<float>(0);
ASSERT_NEAR((e+1)*1., mean, 1e-5);
}
}
@ -664,10 +660,10 @@ TEST_F(DeclarableOpsTests9, concat_test19) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test20) {
auto x0 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x1 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x2 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x3 = NDArrayFactory::create<double>('c', {1, 100, 150});
auto x0 = NDArrayFactory::create<float>('c', {1, 100, 150});
auto x1 = NDArrayFactory::create<float>('c', {1, 100, 150});
auto x2 = NDArrayFactory::create<float>('c', {1, 100, 150});
auto x3 = NDArrayFactory::create<float>('c', {1, 100, 150});
x0.assign(1.0);
x1.assign(2.0);
@ -685,8 +681,8 @@ TEST_F(DeclarableOpsTests9, concat_test20) {
for (int e = 0; e < numOfTads; e++) {
NDArray tad = (*z)(e, {0});
auto mean = tad.meanNumber().e<double>(0);
ASSERT_NEAR((double) e+1, mean, 1e-5);
auto mean = tad.meanNumber().e<float>(0);
ASSERT_NEAR((float) e+1, mean, 1e-5);
}
@ -710,10 +706,10 @@ TEST_F(DeclarableOpsTests9, concat_test21) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test22) {
NDArray x0('c', {1,6}, {1,2,3,4,5,6});
NDArray x1('c', {1,6}, {7,8,9,10,11,12});
NDArray output('f', {2,6}, sd::DataType::DOUBLE);
NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12});
NDArray x0('c', {1,6}, {1,2,3,4,5,6}, sd::DataType::FLOAT32);
NDArray x1('c', {1,6}, {7,8,9,10,11,12}, sd::DataType::FLOAT32);
NDArray output('f', {2,6}, sd::DataType::FLOAT32);
NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32);
sd::ops::concat op;
@ -726,10 +722,10 @@ TEST_F(DeclarableOpsTests9, concat_test22) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test23) {
NDArray x0('c', {1,4}, {1,2,3,4});
NDArray x1('c', {1,4}, {5,6,7,8});
NDArray output('c', {2,4}, sd::DataType::DOUBLE);
NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8});
NDArray x0('c', {1,4}, {1,2,3,4},sd::DataType::FLOAT32);
NDArray x1('c', {1,4}, {5,6,7,8},sd::DataType::FLOAT32);
NDArray output('c', {2,4}, sd::DataType::FLOAT32);
NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}, sd::DataType::FLOAT32);
sd::ops::concat op;
@ -741,10 +737,10 @@ TEST_F(DeclarableOpsTests9, concat_test23) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test24) {
auto x = NDArrayFactory::create<double>('c', {2, 1}, {1, 1});
auto y = NDArrayFactory::create<double>('c', {2, 1}, {0, 0});
auto e = NDArrayFactory::create<double>('c', {2, 2}, {1, 0, 1, 0});
auto z = NDArrayFactory::create<double>('c', {2, 2});
auto x = NDArrayFactory::create<float>('c', {2, 1}, {1, 1});
auto y = NDArrayFactory::create<float>('c', {2, 1}, {0, 0});
auto e = NDArrayFactory::create<float>('c', {2, 2}, {1, 0, 1, 0});
auto z = NDArrayFactory::create<float>('c', {2, 2});
sd::ops::concat op;
auto status = op.execute({&x, &y}, {&z}, {}, {1}, {});
@ -756,10 +752,10 @@ TEST_F(DeclarableOpsTests9, concat_test24) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test25) {
auto x0 = NDArrayFactory::create<double>('c', {1,4}, {1,2,3,4});
auto x1 = NDArrayFactory::create<double>('c', {1,4}, {5,6,7,8});
auto axis = NDArrayFactory::create<double>('c', {1}, {0.});
auto exp = NDArrayFactory::create<double>('c', {2,4}, {1,2,3,4,5,6,7,8});
auto x0 = NDArrayFactory::create<float>('c', {1,4}, {1,2,3,4});
auto x1 = NDArrayFactory::create<float>('c', {1,4}, {5,6,7,8});
auto axis = NDArrayFactory::create<float>('c', {1}, {0.});
auto exp = NDArrayFactory::create<float>('c', {2,4}, {1,2,3,4,5,6,7,8});
sd::ops::concat op;
@ -793,7 +789,7 @@ TEST_F(DeclarableOpsTests9, concat_test26) {
ASSERT_EQ(ND4J_STATUS_OK, result.status());
auto output = result.at(0);
output->printLinearBuffer();
// output->printLinearBuffer();
ASSERT_TRUE(exp.isSameShape(output));
ASSERT_TRUE(exp.equalsTo(output));
@ -802,10 +798,10 @@ TEST_F(DeclarableOpsTests9, concat_test26) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test27) {
auto x1 = NDArrayFactory::create<double>('c', {0,1});
auto x2 = NDArrayFactory::create<double>('c', {0,1});
auto x3 = NDArrayFactory::create<double>('c', {0,1});
auto x4 = NDArrayFactory::create<double>('c', {0,1});
auto x1 = NDArrayFactory::create<float>('c', {0,1});
auto x2 = NDArrayFactory::create<float>('c', {0,1});
auto x3 = NDArrayFactory::create<float>('c', {0,1});
auto x4 = NDArrayFactory::create<float>('c', {0,1});
std::vector<Nd4jLong> expShape = {0, 4};
@ -1245,109 +1241,6 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) {
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_test12) {
const int bS = 5;
const int nOut = 4;
const int axis = 0;
const double clip = 2.;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1]
auto colVect = NDArrayFactory::create<double>('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1});
auto expect = NDArrayFactory::create<double>('c', {bS, nOut});
auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut]
auto y = ( (x / norm2) * clip) * colVect ;
auto temp = (x / norm2) * clip;
for (int j = 0; j < nOut; ++j) {
auto yCol = y({0,0, j,j+1});
const double norm2Col = yCol.reduceNumber(reduce::Norm2).e<double>(0);
if (norm2Col <= clip)
expect({0,0, j,j+1}).assign(yCol);
else
expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) );
}
sd::ops::clipbynorm op;
auto result = op.evaluate({&y}, {clip}, {axis});
auto outFF = result.at(0);
ASSERT_TRUE(expect.isSameShape(outFF));
ASSERT_TRUE(expect.equalsTo(outFF));
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) {
const int bS = 2;
const int nOut = 3;
const double clip = 0.7;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
const OpArgsHolder argsHolderFF({&x}, {clip}, {});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {});
sd::ops::clipbynorm opFF;
sd::ops::clipbynorm_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test2) {
const int bS = 2;
const int nOut = 3;
const int axis = 0;
const double clip = 0.7;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});
sd::ops::clipbynorm opFF;
sd::ops::clipbynorm_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test3) {
const int bS = 2;
const int nOut = 3;
const int axis = 1;
const double clip = 1.;
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});
sd::ops::clipbynorm opFF;
sd::ops::clipbynorm_bp opBP;
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
ASSERT_TRUE(isGradCorrect);
}
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, cumprod_1) {