Shyrma merge max ind (#443)
* provide correct possible output types in mergeMaxIndex op
* clean up the unneeded backprop arg in reverse_bp op
* improve clipByNorm, both ff and bp
* implement and test the clipByAvgNorm_bp op
* pass biases in any case in the dnnl lstm op; they are zeros when the user doesn't provide them to us
* start working on the mkldnn concat op
* further work on mkldnn concat
* fix a missing declaration
* polish mkl ops
* test and fix bugs in the mkl concat op
* fix a linkage error for the Windows CUDA build
* resolve further conflicts with master
* fix format tags in the mkldnn matmul op
* provide an additional type cast in clip.cu
* finally catch the bug in mkldnn tanh_bp

Signed-off-by: Yurii <iuriish@yahoo.com>
Signed-off-by: raver119@gmail.com <raver119@gmail.com>
Co-authored-by: raver119@gmail.com <raver119@gmail.com>
parent 872a511042
commit 76f3553679
@@ -981,12 +981,12 @@ namespace sd {
         * these methods suited for FlatBuffers use
         */
        template <typename T>
-       std::vector<T> getBufferAsVector();
+       std::vector<T> getBufferAsVector() const;
        std::vector<Nd4jLong> getShapeAsVector() const;
        std::vector<int> getShapeAsVectorInt() const;
-       std::vector<Nd4jLong> getShapeInfoAsVector();
-       std::vector<int64_t> getShapeInfoAsFlatVector();
-       std::vector<int64_t> getShapeAsFlatVector();
+       std::vector<Nd4jLong> getShapeInfoAsVector() const;
+       std::vector<int64_t> getShapeInfoAsFlatVector() const;
+       std::vector<int64_t> getShapeAsFlatVector() const;

        /**
         * set new order and shape in case of suitable array length (in-place operation)
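A short sketch (not from the PR, hypothetical function) of what the added const qualifiers enable: these accessors can now be called through a const reference to an NDArray.

#include <array/NDArray.h>

static void inspect(const sd::NDArray& arr) {
    auto shape = arr.getShapeAsVector();          // was already const before this change
    auto flat  = arr.getShapeAsFlatVector();      // const-qualified by this change
    auto data  = arr.getBufferAsVector<float>();  // const-qualified by this change
}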
@@ -982,16 +982,16 @@ std::string NDArray::asString(Nd4jLong limit) {

 ////////////////////////////////////////////////////////////////////////
 template<typename T>
-std::vector<T> NDArray::getBufferAsVector() {
+std::vector<T> NDArray::getBufferAsVector() const {
     std::vector<T> vector(lengthOf());
     for (Nd4jLong e = 0; e < lengthOf(); e++)
         vector[e] = this->e<T>(e);
     return vector;
 }
-BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector(), LIBND4J_TYPES);
+BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT std::vector, NDArray::getBufferAsVector() const, LIBND4J_TYPES);

 ////////////////////////////////////////////////////////////////////////
-std::vector<int64_t> NDArray::getShapeAsFlatVector() {
+std::vector<int64_t> NDArray::getShapeAsFlatVector() const {
     std::vector<int64_t> vector(this->rankOf());
     for (int e = 0; e < this->rankOf(); e++)
         vector[e] = static_cast<int64_t>(this->sizeAt(e));
@@ -1019,7 +1019,7 @@ std::vector<int> NDArray::getShapeAsVectorInt() const {
 }

 ////////////////////////////////////////////////////////////////////////
-std::vector<int64_t> NDArray::getShapeInfoAsFlatVector() {
+std::vector<int64_t> NDArray::getShapeInfoAsFlatVector() const {
     int magicNumber = shape::shapeInfoLength(this->rankOf());
     std::vector<int64_t> vector(magicNumber);
@@ -1030,7 +1030,7 @@ std::vector<int64_t> NDArray::getShapeInfoAsFlatVector() {
 }

 ////////////////////////////////////////////////////////////////////////
-std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() {
+std::vector<Nd4jLong> NDArray::getShapeInfoAsVector() const {
     int magicNumber = shape::shapeInfoLength(this->rankOf());
     std::vector<Nd4jLong> vector(magicNumber);
     for (int e = 0; e < magicNumber; e++)
@@ -15,7 +15,8 @@
 ******************************************************************************/

 //
 // @author raver119@gmail.com
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //

 #include <system/op_boilerplate.h>
@@ -27,24 +28,58 @@
 namespace sd {
 namespace ops {

 //////////////////////////////////////////////////////////////////////////
 CONFIGURABLE_OP_IMPL(clipbyavgnorm, 1, 1, true, 1, 0) {

     auto input = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);

     const bool isInplace = block.isInplace();
-    auto ts = NDArrayFactory::create(T_ARG(0), block.launchContext());
+    auto clipNorm = NDArrayFactory::create(T_ARG(0), block.launchContext());

-    helpers::clipByAveraged(block.launchContext(), *input, *output, *block.getIArguments(), ts, isInplace);
+    helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace, true);

     return Status::OK();
 }

 DECLARE_TYPES(clipbyavgnorm) {
     getOpDescriptor()
             ->setAllowedInputTypes(sd::DataType::ANY)
             ->setAllowedOutputTypes({ALL_FLOATS});
 }

+//////////////////////////////////////////////////////////////////////////
+CUSTOM_OP_IMPL(clipbyavgnorm_bp, 2, 1, false, 1, 0) {
+
+    auto input = INPUT_VARIABLE(0);
+    auto gradO = INPUT_VARIABLE(1);
+
+    auto gradI = OUTPUT_VARIABLE(0);
+
+    const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext());
+
+    helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, true);
+
+    return Status::OK();
+}
+
+//////////////////////////////////////////////////////////////////////////
+DECLARE_SHAPE_FN(clipbyavgnorm_bp) {
+
+    Nd4jLong *newShape = nullptr;
+    COPY_SHAPE(inputShape->at(1), newShape);
+
+    return SHAPELIST(CONSTANT(newShape));
+}
+
+DECLARE_TYPES(clipbyavgnorm_bp) {
+    getOpDescriptor()
+            ->setAllowedInputTypes(0, DataType::ANY)
+            ->setAllowedInputTypes(1, {ALL_FLOATS})
+            ->setAllowedOutputTypes(0, {ALL_FLOATS});
+}

 }
 }
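For readers skimming the diff: clipbyavgnorm no longer has its own helper, it delegates to clipByNorm with useAverage = true. Restated from the CPU and CUDA helpers further down (not extra behaviour), with $c$ the clip value (T_ARG(0)) and $n$ the length of the array or sub-array being clipped, the forward rule is $z = x \cdot \frac{c}{\lVert x\rVert_2 / n}$ when $\lVert x\rVert_2 / n > c$, and $z = x$ otherwise.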
@@ -31,10 +31,10 @@ namespace ops {
     auto input = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);

-    const auto clipNorm = NDArrayFactory::create(input->dataType(), T_ARG(0), block.launchContext());
+    const auto clipNorm = NDArrayFactory::create(output->dataType(), T_ARG(0), block.launchContext());
     const bool isInplace = block.isInplace();

-    helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace);
+    helpers::clipByNorm(block.launchContext(), *input, *output, *block.getIArguments(), clipNorm, isInplace, false);

     return Status::OK();
 }
@@ -45,15 +45,15 @@
     auto gradO = INPUT_VARIABLE(1);

     auto gradI = OUTPUT_VARIABLE(0);
-    const auto clipNorm = NDArrayFactory::create(T_ARG(0));
+    const auto clipNorm = NDArrayFactory::create(gradI->dataType(), T_ARG(0), block.launchContext());

-    helpers::clipByNormBP(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm);
+    helpers::clipByNormBp(block.launchContext(), *input, *gradO, *gradI, *block.getIArguments(), clipNorm, false);

     return Status::OK();
 }

 DECLARE_SHAPE_FN(clipbynorm_bp) {
-    auto inShapeInfo = inputShape->at(0);
+    auto inShapeInfo = inputShape->at(1);

     Nd4jLong *newShape = nullptr;
     COPY_SHAPE(inShapeInfo, newShape);
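The backward rule that the renamed clipByNormBp helper applies (see the clip.cpp and clip.cu hunks below), per clipped (sub-)array, with $g$ the incoming gradient, $c$ the clip value and $\lVert x\rVert$ the (optionally averaged) L2 norm, is $\frac{\partial L}{\partial x_i} = \frac{c}{\lVert x\rVert}\, g_i \left(1 - \frac{x_i \sum_j x_j}{\lVert x\rVert^2}\right)$ when $\lVert x\rVert > c$, and $\frac{\partial L}{\partial x_i} = g_i$ otherwise; this matches the LAMBDA_TT body factor1 * y * (1 - factor2 * x * sum) used in the new helper.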
@@ -23,8 +23,8 @@
 #include<ops/declarable/helpers/transforms.h>
 #include<array>

 namespace sd {
 namespace ops {


 //////////////////////////////////////////////////////////////////////////
@@ -85,6 +85,7 @@ CUSTOM_OP_IMPL(concat, -1, 1, false, 0, 0) {

     // ******** input validation ******** //
     REQUIRE_TRUE(allOfSameType, 0, "CONCAT op: all of input arrays must have same type !");
+    REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT op: output array should have the same type as inputs arrays !");
     REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);

     for(int i = 1; i < numOfNonEmptyArrs; ++i)
@@ -33,7 +33,7 @@ CUSTOM_OP_IMPL(mergemaxindex, -1, 1, false, 0, 0) {
     auto output = OUTPUT_VARIABLE(0);

     std::vector<const NDArray*> inArrs(block.width());

     for(int i = 0; i < block.width(); ++i)
         inArrs[i] = INPUT_VARIABLE(i);
@@ -46,7 +46,8 @@ DECLARE_SYN(MergeMaxIndex, mergemaxindex);

 DECLARE_TYPES(mergemaxindex) {
     getOpDescriptor()
-            ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS});
+            ->setAllowedInputTypes({ALL_INTS, ALL_FLOATS})
+            ->setAllowedOutputTypes({ALL_INDICES});
 }

 DECLARE_SHAPE_FN(mergemaxindex) {
@@ -52,7 +52,7 @@ namespace ops {
     else {
         // check the consistency of input dimensions to reverse along
         shape::checkDimensions(input->rankOf(), axis);
-        helpers::reverse(block.launchContext(), input, output, &axis, false);
+        helpers::reverse(block.launchContext(), input, output, &axis);
     }

     return Status::OK();
@@ -85,7 +85,7 @@ namespace ops {
         // check the consistency of input dimensions to reverse along
         shape::checkDimensions(input->rankOf(), axis);
         // we just reverse back original array
-        helpers::reverse(block.launchContext(), eps, output, &axis, false);
+        helpers::reverse(block.launchContext(), eps, output, &axis);
     }

     return Status::OK();
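Rationale for dropping the flag: reversing is an involution, so the backward pass of reverse is just another reverse of the incoming gradient along the same axes. In formula form, if $y = \operatorname{reverse}(x, \text{axes})$ then $\partial L/\partial x = \operatorname{reverse}(\partial L/\partial y, \text{axes})$, which is exactly what the call above already does; the extra isBackProp/axis-conversion logic removed from the helpers in the reverse.cpp and reverse.cu hunks below was therefore unneeded.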
@@ -36,6 +36,7 @@ namespace sd {

 #if NOT_EXCLUDED(OP_clipbyavgnorm)
     DECLARE_CONFIGURABLE_OP(clipbyavgnorm, 1, 1, true, 1, 0);
+    DECLARE_CUSTOM_OP(clipbyavgnorm_bp, 2, 1, false, 1, 0);
 #endif

 #if NOT_EXCLUDED(OP_cumsum)
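To make the wiring of the new op concrete, here is a minimal test-style sketch (not part of the PR; it assumes the op.evaluate(...) helper used throughout the libnd4j test suite and uses hypothetical input values):

#include <ops/declarable/CustomOperations.h>

using namespace sd;

static void clipByAvgNormBpExample() {
    // x is the forward input, gradO is dL/dOutput coming from the layer above
    auto x     = NDArrayFactory::create<float>('c', {2, 3}, {-3.f, 0.f, 0.f, 4.f, 0.f, 0.f});
    auto gradO = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 1.f, 1.f, 1.f, 1.f, 1.f});

    sd::ops::clipbyavgnorm_bp op;
    // two inputs {x, gradO}, T_ARG(0) = clip value, no int args -> clip over the whole array
    auto results = op.evaluate({&x, &gradO}, {0.5}, {});
    auto gradI = results.at(0);   // dL/dInput, shaped like gradO as per DECLARE_SHAPE_FN above
}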
@@ -15,83 +15,134 @@
 ******************************************************************************/

 //
-// @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
+// @author Yurii Shyrma (iuriish@yahoo.com)
 // @author sgazeos@gmail.com
 // @author raver119@gmail.com
 //

 #include <ops/declarable/helpers/transforms.h>
 #include <helpers/Loops.h>
 #include <execution/Threads.h>

 namespace sd {
 namespace ops {
 namespace helpers {

 //////////////////////////////////////////////////////////////////////////
-template<typename T>
-static void clipByNorm_(NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
+void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace, const bool useAverage) {
|
||||
const int rank = input.rankOf();
|
||||
const auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
|
||||
NDArray* z = nullptr;
|
||||
|
||||
const T normActual = norm2.e<T>(0);
|
||||
const T normClip = clipNorm.e<T>(0);
|
||||
if(isInplace) {
|
||||
z = &input;
|
||||
}
|
||||
else {
|
||||
output.assign(input);
|
||||
z = &output;
|
||||
}
|
||||
|
||||
if (isInplace) {
|
||||
if(dimensions.empty()) {
|
||||
|
||||
if(norm2.lengthOf() == 1) {
|
||||
const NDArray actualNorm = useAverage ? z->reduceAlongDimension(reduce::Norm2, {}) / z->lengthOf() : z->reduceAlongDimension(reduce::Norm2, {});
|
||||
|
||||
if(normActual > normClip)
|
||||
input *= (normClip / normActual);
|
||||
}
|
||||
else {
|
||||
|
||||
auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
const T iNormActual = norm2.e<T>(i);
|
||||
if (iNormActual > normClip)
|
||||
*listOfInSubArrs.at(i) *= normClip / iNormActual;
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
|
||||
}
|
||||
if(actualNorm.e<float>(0) > clipNorm.e<float>(0))
|
||||
*z *= clipNorm / actualNorm;
|
||||
}
|
||||
else {
|
||||
|
||||
if(norm2.lengthOf() == 1) {
|
||||
auto listOfSubArrs = z->allTensorsAlongDimension(dimensions);
|
||||
|
||||
if(normActual > normClip)
|
||||
output.assign(input * (normClip / normActual));
|
||||
else
|
||||
output.assign(input);
|
||||
}
|
||||
else {
|
||||
|
||||
auto listOfInSubArrs = input.allTensorsAlongDimension(dimensions);
|
||||
auto listOfOutSubArrs = output.allTensorsAlongDimension(dimensions);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
auto inputSubArr = listOfInSubArrs.at(i);
|
||||
auto outputSubArr = listOfOutSubArrs.at(i);
|
||||
outputSubArr->assign(inputSubArr);
|
||||
|
||||
const T iNormActual = norm2.e<T>(i);
|
||||
|
||||
if (iNormActual > clipNorm.e<T>(0))
|
||||
*outputSubArr *= clipNorm / iNormActual;
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, listOfInSubArrs.size());
|
||||
}
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
const NDArray actualNorm = useAverage ? listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {}) / listOfSubArrs.at(i)->lengthOf() : listOfSubArrs.at(i)->reduceAlongDimension(reduce::Norm2, {});
|
||||
if(actualNorm.e<float>(0) > clipNorm.e<float>(0))
|
||||
*listOfSubArrs.at(i) *= clipNorm / actualNorm;
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, listOfSubArrs.size());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
||||
BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
|
||||
template<typename T>
|
||||
static void clipByNormBp_(const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage) {
|
||||
|
||||
const int rank = input.rankOf();
|
||||
|
||||
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
|
||||
auto sums = input.reduceAlongDimension(reduce::Sum, dimensions);
|
||||
|
||||
if(norm2.lengthOf() == 1) {
|
||||
|
||||
const T norm = useAverage ? norm2.e<T>(0) / input.lengthOf() : norm2.e<T>(0);
|
||||
|
||||
auto clipVal = clipNorm.e<T>(0);
|
||||
|
||||
if(norm > clipVal) {
|
||||
|
||||
const T sum = sums.e<T>(0); // reduce to scalar
|
||||
const T factor1 = clipVal / norm;
|
||||
const T factor2 = static_cast<T>(1.f) / (norm * norm); // 1 / (norm*norm*norm)
|
||||
|
||||
auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) {
|
||||
return factor1 * y * (static_cast<T>(1.f) - factor2 * x * sum);
|
||||
};
|
||||
|
||||
const_cast<NDArray&>(input).applyPairwiseLambda<T>(const_cast<NDArray&>(gradO), lambda, gradI);
|
||||
}
|
||||
else
|
||||
gradI.assign(gradO);
|
||||
}
|
||||
else {
|
||||
|
||||
auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions});
|
||||
auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions});
|
||||
auto inputSubArrs = input.allTensorsAlongDimension({dimensions});
|
||||
|
||||
auto clipVal = clipNorm.e<T>(0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
|
||||
for (auto i = start; i < stop; i++) {
|
||||
|
||||
auto gradOSubArr = gradOSubArrs.at(i);
|
||||
auto gradISubArr = gradISubArrs.at(i);
|
||||
|
||||
const T norm = useAverage ? norm2.e<T>(i) / gradISubArr->lengthOf() : norm2.e<T>(i);
|
||||
|
||||
if (norm > clipVal) {
|
||||
|
||||
auto inputSubArr = inputSubArrs.at(i);
|
||||
|
||||
const T sum = sums.e<T>(i); // reduce to scalar
|
||||
const T factor1 = clipVal / norm;
|
||||
const T factor2 = static_cast<T>(1.f) / (norm * norm); // 1 / (norm*norm*norm)
|
||||
|
||||
auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) {
|
||||
return factor1 * y * (static_cast<T>(1.f) - factor2 * x * sum);
|
||||
};
|
||||
|
||||
inputSubArr->applyPairwiseLambda<T>(*gradOSubArr, lambda, *gradISubArr);
|
||||
}
|
||||
else
|
||||
gradISubArr->assign(gradOSubArr);
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, gradISubArrs.size());
|
||||
}
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void clipByNormBp_, (const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage), FLOAT_TYPES);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void clipByNormBp(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage) {
|
||||
|
||||
const NDArray& castedInput = gradI.dataType() == input.dataType() ? input : input.cast(gradI.dataType());
|
||||
|
||||
BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBp_, (castedInput, gradO, gradI, dimensions, clipNorm, useAverage), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
|
@ -132,125 +183,6 @@ void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, co
|
|||
|
||||
BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
static void clipByNormBP_(const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
|
||||
|
||||
const int rank = input.rankOf();
|
||||
|
||||
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
|
||||
|
||||
if(norm2.lengthOf() == 1) {
|
||||
|
||||
const T N = norm2.e<T>(0);
|
||||
|
||||
auto cn = clipNorm.e<T>(0);
|
||||
|
||||
if(N > cn) {
|
||||
|
||||
const T sumOfProd = (input * gradO).reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
|
||||
const T factor1 = static_cast<T>(1.f) / N;
|
||||
const T factor3 = factor1 / (N * N); // 1 / (N*N*N)
|
||||
|
||||
auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) {
|
||||
return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd);
|
||||
};
|
||||
|
||||
(const_cast<NDArray&>(input)).applyPairwiseLambda<T>(const_cast<NDArray&>(gradO), lambda, gradI);
|
||||
}
|
||||
else
|
||||
gradI.assign(gradO);
|
||||
}
|
||||
else {
|
||||
|
||||
auto gradISubArrs = gradI.allTensorsAlongDimension({dimensions});
|
||||
auto gradOSubArrs = gradO.allTensorsAlongDimension({dimensions});
|
||||
auto inputSubArrs = input.allTensorsAlongDimension({dimensions});
|
||||
|
||||
auto cn = clipNorm.e<T>(0);
|
||||
|
||||
auto func = PRAGMA_THREADS_FOR {
|
||||
for (auto i = start; i < stop; i++) {
|
||||
T N = norm2.e<T>(i);
|
||||
|
||||
auto gradOSubArr = gradOSubArrs.at(i);
|
||||
auto gradISubArr = gradISubArrs.at(i);
|
||||
|
||||
if (N > cn) {
|
||||
auto inputSubArr = inputSubArrs.at(i);
|
||||
const T sumOfProd = (*inputSubArr * *gradOSubArr).reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
|
||||
const T factor1 = static_cast<T>(1.f) / N;
|
||||
const T factor3 = factor1 / (N * N); // 1 / (N*N*N)
|
||||
|
||||
auto lambda = LAMBDA_TT(elem1, elem2, cn, sumOfProd, factor1, factor3) {
|
||||
return cn * (factor1 * elem2 - factor3 * elem1 * sumOfProd);
|
||||
};
|
||||
|
||||
inputSubArr->applyPairwiseLambda<T>(*gradOSubArr, lambda, *gradISubArr);
|
||||
} else
|
||||
gradISubArr->assign(gradOSubArr);
|
||||
}
|
||||
};
|
||||
samediff::Threads::parallel_tad(func, 0, gradISubArrs.size());
|
||||
}
|
||||
}
|
||||
|
||||
void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {
|
||||
BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBP_, (input, gradO, gradI, dimensions, clipNorm), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void clipByNormBP_, (const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm), FLOAT_TYPES);
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
static void clipByAveraged_(NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
||||
|
||||
auto cn = clipNorm.e<T>(0);
|
||||
if (dimensions.size() == 0) {
|
||||
// all-reduce
|
||||
T n2 = input.reduceNumber(reduce::Norm2).e<T>(0) / input.lengthOf();
|
||||
if (n2 <= cn) {
|
||||
if (!isInplace)
|
||||
output.assign(input);
|
||||
}
|
||||
else {
|
||||
const T factor = cn / n2;
|
||||
auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
|
||||
input.applyLambda<T>(lambda, output);
|
||||
}
|
||||
}
|
||||
else {
|
||||
// along dimension
|
||||
auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false);
|
||||
if (!isInplace)
|
||||
output.assign(input);
|
||||
auto tads = output.allTensorsAlongDimension(dimensions);
|
||||
// TODO: make this CUDA-compliant somehow
|
||||
for (int e = 0; e < tads.size(); e++) {
|
||||
T n2 = norm2.e<T>(e) / tads.at(e)->lengthOf();
|
||||
const T factor = cn / n2;
|
||||
if (n2 > cn) {
|
||||
auto lambda = LAMBDA_T(_x, factor) {return _x * factor;};
|
||||
tads.at(e)->applyLambda<T>(lambda, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);
|
||||
|
||||
/*
|
||||
if (d1 > params[1])
|
||||
return params[1];
|
||||
else if (d1 < params[0])
|
||||
return params[0];
|
||||
else return d1;
|
||||
*/
|
||||
|
||||
template <typename T>
|
||||
static void clipByValue_(NDArray& input, double leftBound, double rightBound, NDArray& output) {
|
||||
|
|
|
@@ -29,7 +29,7 @@ namespace helpers {


 //////////////////////////////////////////////////////////////////////////
-template<typename T>
+template<typename X, typename Z>
 static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& output) {

     const Nd4jLong numArgs = inArrs.size();
@@ -37,17 +37,18 @@ static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& output) {

     auto func = PRAGMA_THREADS_FOR {
         for (auto e = start; e < stop; e++) {
-            T max = -DataTypeUtils::max<T>();
-            Nd4jLong idx = 0;
+            X max = -DataTypeUtils::max<X>();
+            Z idx = static_cast<Z>(0);

             for (Nd4jLong i = 0; i < numArgs; i++) {
-                T v = inArrs[i]->e<T>(e);
+                X v = inArrs[i]->t<X>(e);
                 if (v > max) {
                     max = v;
-                    idx = i;
+                    idx = static_cast<Z>(i);
                 }
             }
-            output.p(e, idx);
+            // FIXME, use .r<Z>(e)
+            output.t<Z>(e) = static_cast<Z>(idx);
         }
     };
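As a quick illustration of what the templated helper now computes (illustrative values, not a test from the PR): given three inputs x0 = [1, 0, 5], x1 = [2, 9, 0], x2 = [3, 1, 4], mergeMaxIndex writes [2, 1, 0] into the output, i.e. for every position the index of the input array holding the maximum, stored via output.t<Z>(e) in whatever index type Z (any of ALL_INDICES) the output array was created with.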
@@ -55,14 +56,14 @@ static void mergeMaxIndex_(const std::vector<const NDArray*>& inArrs, NDArray& output) {
 }

 void mergeMaxIndex(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output) {
-    BUILD_SINGLE_SELECTOR(inArrs[0]->dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES);
+    BUILD_DOUBLE_SELECTOR(inArrs[0]->dataType(), output.dataType(), mergeMaxIndex_, (inArrs, output), LIBND4J_TYPES, INDEXING_TYPES);
 }


 //////////////////////////////////////////////////////////////////////////
 template<typename T>
 static void mergeMax_(const std::vector<const NDArray*>& inArrs, NDArray& output) {

     const Nd4jLong numArgs = inArrs.size();
     auto x = inArrs[0];
@ -89,15 +90,15 @@ void mergeMax(sd::LaunchContext * context, const std::vector<const NDArray*>& in
|
|||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<NDArray*>& outArrs) {
|
||||
|
||||
|
||||
// outArrs.size() == inArrs.size() - 1
|
||||
const Nd4jLong numArgs = outArrs.size();
|
||||
// last array is gradient
|
||||
const auto gradient = inArrs[numArgs]->bufferAsT<T>();
|
||||
auto length = inArrs[numArgs]->lengthOf();
|
||||
|
||||
|
||||
bool bSameOrderAndEws1 = (1 == inArrs[numArgs]->ews());
|
||||
|
||||
|
||||
if (bSameOrderAndEws1) {
|
||||
auto gradOrdering = inArrs[numArgs]->ordering();
|
||||
|
||||
|
@ -108,8 +109,8 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
|
|||
bSameOrderAndEws1 &= (1 == outArrs[i]->ews());
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
if(bSameOrderAndEws1){
|
||||
auto func = PRAGMA_THREADS_FOR{
|
||||
for (auto e = start; e < stop; e++) {
|
||||
|
@ -130,7 +131,7 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
|
|||
samediff::Threads::parallel_for(func, 0, length);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
auto gradShape = inArrs[numArgs]->shapeInfo();
|
||||
std::vector<bool> vbSameShaepeAndStrides(numArgs);
|
||||
for (int i = 0; i < numArgs; ++i) {
|
||||
|
@ -145,12 +146,12 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
|
|||
shape::index2coordsCPU(start, e, gradShape, coords);
|
||||
|
||||
const auto gradOffset = shape::getOffset(gradShape, coords);
|
||||
|
||||
|
||||
T max = -DataTypeUtils::max<T>();
|
||||
Nd4jLong nMaxIndex = 0;
|
||||
|
||||
for (Nd4jLong i = 0; i < numArgs; i++) {
|
||||
|
||||
|
||||
const auto xOffset = vbSameShaepeAndStrides[i] ? gradOffset : shape::getOffset(inArrs[i]->shapeInfo(), coords);
|
||||
const T* v = inArrs[i]->bufferAsT<T>();
|
||||
if (v[xOffset] > max) {
|
||||
|
@ -160,7 +161,7 @@ static void mergeMaxBp_(const std::vector<const NDArray*>& inArrs, std::vector<N
|
|||
}
|
||||
|
||||
const auto zOffset = vbSameShaepeAndStrides[nMaxIndex] ? gradOffset : shape::getOffset(outArrs[nMaxIndex]->shapeInfo(), coords);
|
||||
|
||||
|
||||
T* z = outArrs[nMaxIndex]->bufferAsT<T>();
|
||||
z[zOffset] = gradient[gradOffset];
|
||||
}
|
||||
|
|
|
@@ -193,13 +193,10 @@ static void reverseSequence_(sd::LaunchContext * context, const NDArray* input,
 }

 //////////////////////////////////////////////////////////////////////////
-void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs, bool isBackProp) {
+void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs) {

-    // we need to reverse axis only if that's new op
-    std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
-
-    auto listOut = output->allTensorsAlongDimension(dimensions);
-    auto listIn = input->allTensorsAlongDimension(dimensions);
+    auto listOut = output->allTensorsAlongDimension(*intArgs);
+    auto listIn = input->allTensorsAlongDimension(*intArgs);

     NDArray *subArrIn, *subArrOut;
@ -0,0 +1,334 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||
// @author sgazeos@gmail.com
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
|
||||
#include <ops/declarable/helpers/transforms.h>
|
||||
#include <helpers/ShapeUtils.h>
|
||||
#include <helpers/PointersManager.h>
|
||||
#include <helpers/ConstantTadHelper.h>
|
||||
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
__global__ static void clipByNormCuda(const void* vClipNorm, const void* vNorm, const Nd4jLong* normShapeInfo, void* vz, const Nd4jLong* zShapeInfo, const int* dimensions, const int dimsLen, const bool useAverage) {
|
||||
|
||||
const T clipNorm = *reinterpret_cast<const T*>(vClipNorm);
|
||||
const T* norm = reinterpret_cast<const T*>(vNorm);
|
||||
T* z = reinterpret_cast<T*>(vz);
|
||||
|
||||
__shared__ Nd4jLong zLen, tadLen, totalThreads;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
|
||||
zLen = shape::length(zShapeInfo);
|
||||
tadLen = zLen / shape::length(normShapeInfo);
|
||||
totalThreads = gridDim.x * blockDim.x;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int zCoords[MAX_RANK], normCoords[MAX_RANK];
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords);
|
||||
|
||||
// deduce norm coords
|
||||
for (int j = 0; j < dimsLen; ++j)
|
||||
normCoords[j] = zCoords[dimensions[j]];
|
||||
|
||||
const T actualNorm = useAverage ? norm[shape::getOffset(normShapeInfo, normCoords)] / tadLen : norm[shape::getOffset(normShapeInfo, normCoords)];
|
||||
|
||||
if(actualNorm > clipNorm)
|
||||
z[shape::getOffset(zShapeInfo, zCoords)] *= clipNorm / actualNorm;
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
__host__ static void clipByNormCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream,
|
||||
const void* vClipNorm, const void* vNorm, const Nd4jLong* normShapeInfo, void* vz, const Nd4jLong* zShapeInfo,
|
||||
const int* dimensions, const int dimsLen, const bool useAverage) {
|
||||
|
||||
clipByNormCuda<T><<<blocksPerGrid, threadsPerBlock, 512, *stream>>>(vClipNorm, vNorm, normShapeInfo, vz, zShapeInfo, dimensions, dimsLen, useAverage);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void clipByNorm(sd::LaunchContext* context, NDArray& input, NDArray& output, const std::vector<int>& dims, const NDArray& clipNorm, const bool isInplace, const bool useAverage) {
|
||||
|
||||
NDArray* z = nullptr;
|
||||
|
||||
if(isInplace) {
|
||||
z = &input;
|
||||
}
|
||||
else {
|
||||
output.assign(input);
|
||||
z = &output;
|
||||
}
|
||||
|
||||
if(dims.empty()) {
|
||||
|
||||
const NDArray actualNorm = useAverage ? z->reduceAlongDimension(reduce::Norm2, {}) / z->lengthOf() : z->reduceAlongDimension(reduce::Norm2, {});
|
||||
|
||||
if(actualNorm.e<float>(0) > clipNorm.e<float>(0))
|
||||
*z *= clipNorm / actualNorm;
|
||||
}
|
||||
else {
|
||||
|
||||
const NDArray actualNorms = z->reduceAlongDimension(reduce::Norm2, dims);
|
||||
|
||||
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(z->rankOf(), dims);
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int blocksPerGrid = (z->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
PointersManager manager(context, "clipByNorm");
|
||||
|
||||
const int* dimensions = reinterpret_cast<const int*>(manager.replicatePointer(dimsToExclude.data(), dimsToExclude.size() * sizeof(int)));
|
||||
|
||||
NDArray::prepareSpecialUse({z}, {z, &actualNorms, &clipNorm});
|
||||
BUILD_SINGLE_SELECTOR(z->dataType(), clipByNormCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), clipNorm.specialBuffer(), actualNorms.specialBuffer(), actualNorms.specialShapeInfo(), z->specialBuffer(), z->specialShapeInfo(), dimensions, (int)dimsToExclude.size(), useAverage), FLOAT_TYPES);
|
||||
NDArray::registerSpecialUse({z}, {z, &actualNorms, &clipNorm});
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
}
|
||||
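A small worked example of the launch configuration used above (assuming MAX_NUM_THREADS is 1024, so threadsPerBlock = 512): for an output of 100000 elements, blocksPerGrid = (100000 + 512 - 1) / 512 = 196, i.e. ordinary ceiling division, and the grid-stride loop inside clipByNormCuda then covers every element even when zLen is not a multiple of the block size.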
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
__global__ static void clipByNormBpCuda(const void* vClipNorm,
|
||||
const void* vx, const Nd4jLong* xShapeInfo, // input
|
||||
const void* vy, const Nd4jLong* yShapeInfo, // gradO
|
||||
const void* vNorm, const Nd4jLong* normShapeInfo,
|
||||
const void* vSum, const Nd4jLong* sumShapeInfo,
|
||||
void* vz, const Nd4jLong* zShapeInfo, // gradI
|
||||
const int* dimensions, const int dimsLen, const bool useAverage) {
|
||||
|
||||
const T clipNorm = *reinterpret_cast<const T*>(vClipNorm);
|
||||
const T* norm = reinterpret_cast<const T*>(vNorm);
|
||||
const T* sum = reinterpret_cast<const T*>(vSum);
|
||||
const T* x = reinterpret_cast<const T*>(vx);
|
||||
const T* y = reinterpret_cast<const T*>(vy);
|
||||
T* z = reinterpret_cast<T*>(vz);
|
||||
|
||||
__shared__ Nd4jLong zLen, tadLen, totalThreads;
|
||||
__shared__ bool sameOffsets;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
|
||||
zLen = shape::length(zShapeInfo);
|
||||
tadLen = zLen / shape::length(normShapeInfo);
|
||||
totalThreads = gridDim.x * blockDim.x;
|
||||
|
||||
sameOffsets = shape::haveSameShapeAndStrides(xShapeInfo, yShapeInfo, zShapeInfo);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
int zCoords[MAX_RANK], normCoords[MAX_RANK];
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||
|
||||
shape::index2coords(i, zShapeInfo, zCoords);
|
||||
|
||||
const auto zOffset = shape::getOffset(zShapeInfo, zCoords);
|
||||
const auto yOffset = sameOffsets ? zOffset : shape::getOffset(yShapeInfo, zCoords);
|
||||
|
||||
// deduce norm coords
|
||||
for (int j = 0; j < dimsLen; ++j)
|
||||
normCoords[j] = zCoords[dimensions[j]];
|
||||
|
||||
const T actualNorm = useAverage ? norm[shape::getOffset(normShapeInfo, normCoords)] / tadLen : norm[shape::getOffset(normShapeInfo, normCoords)];
|
||||
|
||||
if(actualNorm > clipNorm) {
|
||||
|
||||
const T sumVal = sum[shape::getOffset(sumShapeInfo, normCoords)];
|
||||
const auto xOffset = sameOffsets ? zOffset : shape::getOffset(xShapeInfo, zCoords);
|
||||
|
||||
z[zOffset] = (clipNorm / actualNorm) * y[yOffset] * (static_cast<T>(1.f) - (x[xOffset] * sumVal) / (actualNorm * actualNorm));
|
||||
}
|
||||
else
|
||||
z[zOffset] = y[yOffset];
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename T>
|
||||
void clipByNormBp_(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dims, const NDArray& clipNorm, const bool useAverage) {
|
||||
|
||||
const int rank = input.rankOf();
|
||||
|
||||
auto actualNorms = input.reduceAlongDimension(reduce::Norm2, dims);
|
||||
|
||||
if(actualNorms.lengthOf() == 1) {
|
||||
|
||||
const T norm = useAverage ? actualNorms.e<T>(0) / static_cast<T>(input.lengthOf()) : actualNorms.e<T>(0);
|
||||
|
||||
auto clipVal = clipNorm.e<T>(0);
|
||||
|
||||
if(norm > clipVal) {
|
||||
|
||||
const T sum = input.reduceNumber(reduce::Sum).e<T>(0); // reduce to scalar
|
||||
const T factor1 = clipVal / norm;
|
||||
const T factor2 = static_cast<T>(1.f) / (norm * norm); // 1 / (norm*norm*norm)
|
||||
|
||||
auto lambda = LAMBDA_TT(x, y, sum, factor1, factor2) {
|
||||
return factor1 * y * (static_cast<T>(1.f) - factor2 * x * sum);
|
||||
};
|
||||
|
||||
const_cast<NDArray&>(input).applyPairwiseLambda(const_cast<NDArray&>(gradO), lambda, gradI);
|
||||
}
|
||||
else
|
||||
gradI.assign(gradO);
|
||||
}
|
||||
else {
|
||||
|
||||
const NDArray actualNorms = input.reduceAlongDimension(reduce::Norm2, dims);
|
||||
const NDArray sums = input.reduceAlongDimension(reduce::Sum, dims);
|
||||
|
||||
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(gradI.rankOf(), dims);
|
||||
|
||||
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||
const int blocksPerGrid = (gradI.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||
|
||||
PointersManager manager(context, "clipByNormBp");
|
||||
|
||||
const int* dimensions = reinterpret_cast<const int*>(manager.replicatePointer(dimsToExclude.data(), dimsToExclude.size() * sizeof(int)));
|
||||
|
||||
NDArray::prepareSpecialUse({&gradI}, {&actualNorms, &sums, &clipNorm, &input, &gradO});
|
||||
clipByNormBpCuda<T><<<blocksPerGrid, threadsPerBlock, 512, *context->getCudaStream()>>>(clipNorm.specialBuffer(), input.specialBuffer(), input.specialShapeInfo(), gradO.specialBuffer(), gradO.specialShapeInfo(), actualNorms.specialBuffer(), actualNorms.specialShapeInfo(), sums.specialBuffer(), sums.specialShapeInfo(), gradI.specialBuffer(), gradI.specialShapeInfo(), dimensions, (int)dimsToExclude.size(), useAverage);
|
||||
NDArray::registerSpecialUse({&gradI}, {&actualNorms, &sums, &clipNorm, &input, &gradO});
|
||||
|
||||
manager.synchronize();
|
||||
}
|
||||
}
|
||||
BUILD_SINGLE_TEMPLATE(template void clipByNormBp_, (sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage), FLOAT_TYPES);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void clipByNormBp(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage) {
|
||||
|
||||
const NDArray& castedInput = gradI.dataType() == input.dataType() ? input : input.cast(gradI.dataType());
|
||||
BUILD_SINGLE_SELECTOR(gradI.dataType(), clipByNormBp_, (context, castedInput, gradO, gradI, dimensions, clipNorm, useAverage), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
void clipByGlobalNorm_(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
|
||||
NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))
|
||||
|
||||
for (auto i = 0; i < inputs.size(); i++) {
|
||||
auto input = inputs[i];
|
||||
auto l2norm = input->reduceNumber(reduce::Norm2);
|
||||
globalNorm += l2norm * l2norm;
|
||||
}
|
||||
|
||||
globalNorm.applyTransform(transform::Sqrt, globalNorm); // = sd::math::nd4j_sqrt(globalNorm);
|
||||
outputs[inputs.size()]->p(0, globalNorm);
|
||||
globalNorm.syncToHost();
|
||||
const T factor = static_cast<T>(clipNorm) / globalNorm.e<T>(0);
|
||||
|
||||
for (size_t e = 0; e < inputs.size(); e++) {
|
||||
// all-reduce
|
||||
auto input = inputs[e];
|
||||
auto output = outputs[e];
|
||||
|
||||
if (globalNorm.e<double>(0) <= clipNorm) {
|
||||
output->assign(input);
|
||||
}
|
||||
else {
|
||||
|
||||
auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
|
||||
input->applyLambda(lambda, *output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
|
||||
BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (context, inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);
|
||||
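For reference, the quantity clipByGlobalNorm_ computes above is $g = \sqrt{\sum_k \lVert t_k \rVert_2^2}$ over all input tensors $t_k$; when $g > c$ (the clipNorm argument) every tensor is scaled by $c / g$, otherwise it is copied through unchanged, and the extra output at index inputs.size() receives $g$ itself.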
|
||||
|
||||
template <typename T>
|
||||
static void __global__ clipByValueKernel(void* input, const Nd4jLong* inputShape, void* output, const Nd4jLong* outputShape, double leftBound, double rightBound) {
|
||||
__shared__ T* outputBuf;
|
||||
__shared__ T* inputBuf;
|
||||
__shared__ Nd4jLong length;
|
||||
__shared__ bool linearBuffers;
|
||||
if (threadIdx.x == 0) {
|
||||
outputBuf = reinterpret_cast<T *>(output);
|
||||
inputBuf = reinterpret_cast<T *>(input);
|
||||
length = shape::length(inputShape);
|
||||
linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1;
|
||||
}
|
||||
__syncthreads();
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const auto step = gridDim.x * blockDim.x;
|
||||
|
||||
for (Nd4jLong e = tid; e < length; e += step) {
|
||||
if (linearBuffers) {
|
||||
if (inputBuf[e] > rightBound) outputBuf[e] = (T) rightBound;
|
||||
else if (inputBuf[e] < leftBound) outputBuf[e] = (T) leftBound;
|
||||
else outputBuf[e] = inputBuf[e];
|
||||
}
|
||||
else {
|
||||
auto inputOffset = shape::getIndexOffset(e, inputShape);
|
||||
auto outputOffset = shape::getIndexOffset(e, outputShape);
|
||||
if (inputBuf[inputOffset] > rightBound) outputBuf[outputOffset] = (T) rightBound;
|
||||
else if (inputBuf[inputOffset] < leftBound) outputBuf[outputOffset] = (T) leftBound;
|
||||
else outputBuf[outputOffset] = inputBuf[outputOffset];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void clipByValue_(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
|
||||
auto stream = context->getCudaStream();
|
||||
if (!input.isActualOnDeviceSide())
|
||||
input.syncToDevice();
|
||||
NDArray::prepareSpecialUse({&output}, {&input});
|
||||
clipByValueKernel<T><<<256, 512, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftBound, rightBound);
|
||||
NDArray::registerSpecialUse({&output}, {&input});
|
||||
}
|
||||
|
||||
void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
|
||||
BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (context, input, leftBound, rightBound, output), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void clipByValue_, (sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);, FLOAT_TYPES);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -210,14 +210,10 @@ namespace helpers {
|
|||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs, bool isBackProp) {
|
||||
// we need to reverse axis only if that's new op
|
||||
std::vector<int> dimensions = isBackProp ? ShapeUtils::evalDimsToExclude(input->rankOf(), *intArgs) : *intArgs;
|
||||
std::vector<int> axis = ShapeUtils::evalDimsToExclude(input->rankOf(), dimensions);
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), dimensions);
|
||||
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), dimensions);
|
||||
|
||||
void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs) {
|
||||
|
||||
auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input->shapeInfo(), *intArgs);
|
||||
auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output->shapeInfo(), *intArgs);
|
||||
|
||||
NDArray::prepareSpecialUse({output}, {input});
|
||||
|
||||
|
|
|
@ -300,269 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray
|
|||
manager.synchronize();
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// x - input, y - gradO, z - gradI
|
||||
template<typename X, typename Z>
|
||||
__global__ static void clipByNormBPWholeArrCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) {
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if(tid >= shape::length(zShapeInfo))
|
||||
return;
|
||||
|
||||
const auto x = reinterpret_cast<const X*>(vx);
|
||||
const auto y = reinterpret_cast<const Z*>(vy);
|
||||
auto z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
auto reducBuff = reinterpret_cast<Z*>(vreducBuff);
|
||||
uint* count = reinterpret_cast<uint*>(vreducBuff) + 16384;
|
||||
|
||||
__shared__ Z* shMem;
|
||||
__shared__ Nd4jLong len;
|
||||
__shared__ bool amIinLastBlock;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
extern __shared__ unsigned char shmem[];
|
||||
shMem = reinterpret_cast<Z*>(shmem);
|
||||
|
||||
len = shape::length(zShapeInfo); // xLen = yLen = zLen
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// fill shared memory with array elements
|
||||
const auto xVal = x[shape::getIndexOffset(tid, xShapeInfo)];
|
||||
const auto yVal = y[shape::getIndexOffset(tid, yShapeInfo)];
|
||||
|
||||
shMem[2*threadIdx.x] = static_cast<Z>(xVal * xVal); // for norm
|
||||
shMem[2*threadIdx.x + 1] = static_cast<Z>(xVal * yVal); // for input * gradO
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// accumulate sum per block
|
||||
for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
|
||||
|
||||
if (threadIdx.x < activeThreads && tid + activeThreads < len) {
|
||||
|
||||
shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)];
|
||||
shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// store accumulated sums in reduction buffer (reducBuff)
|
||||
if (threadIdx.x == 0) {
|
||||
|
||||
reducBuff[2*blockIdx.x] = shMem[0];
|
||||
reducBuff[2*blockIdx.x + 1] = shMem[1];
|
||||
|
||||
__threadfence();
|
||||
|
||||
amIinLastBlock = gridDim.x == 1 || (atomicInc(count, gridDim.x) == gridDim.x - 1);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// shared memory of last block is used for final summation of values stored in reduction buffer
|
||||
if (amIinLastBlock) {
|
||||
|
||||
for (int i = threadIdx.x; i < gridDim.x; i += blockDim.x) {
|
||||
|
||||
shMem[2*threadIdx.x] = (i == threadIdx.x ) ? reducBuff[2*i] : reducBuff[2*i] + shMem[2*threadIdx.x];
|
||||
shMem[2*threadIdx.x + 1] = (i == threadIdx.x ) ? reducBuff[2*i + 1] : reducBuff[2*i + 1] + shMem[2*threadIdx.x + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// accumulate sum
|
||||
for (int activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
|
||||
|
||||
if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < gridDim.x) {
|
||||
shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)];
|
||||
shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
|
||||
reducBuff[0] = math::nd4j_sqrt<Z,Z>(shMem[0]);
|
||||
reducBuff[1] = shMem[1];
|
||||
count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// x - input, y - gradO, z - gradI
|
||||
template<typename X, typename Z>
|
||||
__global__ static void clipByNormBPCalcGradCuda(const void* vx, const Nd4jLong* xShapeInfo, const void* vy, const Nd4jLong* yShapeInfo, void* vz, const Nd4jLong* zShapeInfo, void* vreducBuff, const Z clipNormVal) {
|
||||
|
||||
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
const Nd4jLong len = shape::length(zShapeInfo); // xLen = yLen = zLen
|
||||
|
||||
if(tid >= len)
|
||||
return;
|
||||
|
||||
const auto x = reinterpret_cast<const X*>(vx);
|
||||
const auto y = reinterpret_cast<const Z*>(vy);
|
||||
auto z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
__shared__ Z norm, sumOfProd;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
|
||||
norm = reinterpret_cast<Z*>(vreducBuff)[0];
|
||||
sumOfProd = reinterpret_cast<Z*>(vreducBuff)[1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const auto yOffset = shape::getIndexOffset(tid, yShapeInfo);
|
||||
const auto zOffset = shape::getIndexOffset(tid, zShapeInfo);
|
||||
|
||||
if(norm > clipNormVal) {
|
||||
|
||||
const auto xOffset = shape::getIndexOffset(tid, xShapeInfo);
|
||||
|
||||
const Z factor1 = static_cast<Z>(1) / norm; // 1 / norm
|
||||
const Z factor2 = factor1 / (norm * norm); // 1 / (norm * norm * norm)
|
||||
|
||||
z[zOffset] = clipNormVal * (factor1 * y[yOffset] - factor2 * sumOfProd * x[xOffset]);
|
||||
}
|
||||
else {
|
||||
z[zOffset] = y[yOffset];
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// x - input, y - gradO, z - gradI
|
||||
template<typename X, typename Z>
|
||||
__global__ static void clipByNormBPTadsCuda(const void* vx, const Nd4jLong* xTadShapeInfo, const Nd4jLong* xTadOffsets, const void* vy, const Nd4jLong* yTadShapeInfo, const Nd4jLong* yTadOffsets, void* vz, const Nd4jLong* zTadShapeInfo, const Nd4jLong* zTadOffsets, const Z clipNormVal) {
|
||||
|
||||
const auto x = reinterpret_cast<const X*>(vx);
|
||||
const auto y = reinterpret_cast<const Z*>(vy);
|
||||
auto z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
__shared__ Z* shMem;
|
||||
__shared__ Nd4jLong tadLen;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
|
||||
extern __shared__ unsigned char shmem[];
|
||||
shMem = reinterpret_cast<Z*>(shmem);
|
||||
tadLen = shape::length(zTadShapeInfo); // xTadLen = yTadLen = zTadLen
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
const auto* xTad = x + xTadOffsets[blockIdx.x];
|
||||
const auto* yTad = y + yTadOffsets[blockIdx.x];
|
||||
auto* zTad = z + zTadOffsets[blockIdx.x];
|
||||
|
||||
// *** FIRST STAGE - ACCUMULATE REQUIRED SUMS *** //
|
||||
|
||||
Z norm = 0;
|
||||
Z sumOfProd = 0;
|
||||
|
||||
for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) {
|
||||
|
||||
const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo);
|
||||
const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo);
|
||||
|
||||
shMem[2*threadIdx.x] = static_cast<Z>(xTad[xOffset] * xTad[xOffset]); // for norm
|
||||
shMem[2*threadIdx.x + 1] = static_cast<Z>(xTad[xOffset] * yTad[yOffset]); // for input * gradO
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// accumulate sum per block
|
||||
for (uint activeThreads = blockDim.x / 2; activeThreads > 0; activeThreads /= 2) {
|
||||
|
||||
if (threadIdx.x < activeThreads && i + activeThreads < tadLen) {
|
||||
|
||||
shMem[2*threadIdx.x] += shMem[2*(threadIdx.x + activeThreads)];
|
||||
shMem[2*threadIdx.x + 1] += shMem[2*(threadIdx.x + activeThreads) + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
norm += shMem[0];
|
||||
sumOfProd += shMem[1];
|
||||
}
|
||||
|
||||
// *** SECOND STAGE - GRADIENT CALCULATION *** //
|
||||
|
||||
norm = math::nd4j_sqrt<Z,Z>(norm);
|
||||
|
||||
for (uint i = threadIdx.x; i < tadLen; i += blockDim.x) {
|
||||
|
||||
const auto yOffset = shape::getIndexOffset(i, yTadShapeInfo);
|
||||
const auto zOffset = shape::getIndexOffset(i, zTadShapeInfo);
|
||||
|
||||
if(norm > clipNormVal) {
|
||||
|
||||
const auto xOffset = shape::getIndexOffset(i, xTadShapeInfo);
|
||||
|
||||
const Z factor1 = static_cast<Z>(1) / norm; // 1 / norm
|
||||
const Z factor2 = factor1 / (norm * norm); // 1 / (norm * norm * norm)
|
||||
|
||||
zTad[zOffset] = clipNormVal * (factor1 * yTad[yOffset] - factor2 * sumOfProd * xTad[xOffset]);
|
||||
}
|
||||
else {
|
||||
zTad[zOffset] = yTad[yOffset];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
template<typename X, typename Z>
|
||||
static void clipByNormBPCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
||||
const void* vx, const Nd4jLong* xShapeInfo, const Nd4jLong* xTadOffsets,
|
||||
const void* vy, const Nd4jLong* yShapeInfo, const Nd4jLong* yTadOffsets,
|
||||
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong* zTadOffsets,
|
||||
void* vreducBuff, const double clipNormVal) {
|
||||
|
||||
if(xTadOffsets == nullptr) { // means whole array
|
||||
clipByNormBPWholeArrCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast<Z>(clipNormVal));
|
||||
clipByNormBPCalcGradCuda<X,Z><<<blocksPerGrid, threadsPerBlock, 256, *stream>>>(vx, xShapeInfo, vy, yShapeInfo, vz, zShapeInfo, vreducBuff, static_cast<Z>(clipNormVal));
|
||||
}
|
||||
else // means tads using
|
||||
clipByNormBPTadsCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, xTadOffsets, vy, yShapeInfo, yTadOffsets, vz, zShapeInfo, zTadOffsets, static_cast<Z>(clipNormVal));
|
||||
}
|
||||
BUILD_DOUBLE_TEMPLATE(template void clipByNormBPCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, const Nd4jLong* xTadOffsets, const void *vy, const Nd4jLong *yShapeInfo, const Nd4jLong* yTadOffsets, void *vz, const Nd4jLong *zShapeInfo, const Nd4jLong* zTadOffsets, void* vreducBuff, const double clipNormVal), FLOAT_TYPES, FLOAT_TYPES);
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
void clipByNormBP(sd::LaunchContext* context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm) {

    PointersManager manager(context, "clipByNormBP");

    const double clipNormVal = clipNorm.e<double>(0);

    const auto xType = input.dataType();
    const auto zType = gradI.dataType();

    const int threadsPerBlock = MAX_NUM_THREADS / 2;
    const int sharedMem = threadsPerBlock * 2 * input.sizeOfT() + 128;

    NDArray::prepareSpecialUse({&gradI}, {&input, &gradO});

    if(dimensions.empty() || dimensions.size() == input.rankOf()) {     // means whole array

        const int blocksPerGrid = (input.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
        BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), input.specialShapeInfo(), nullptr, gradO.specialBuffer(), gradO.specialShapeInfo(), nullptr, gradI.specialBuffer(), gradI.specialShapeInfo(), nullptr, context->getReductionPointer(), clipNormVal), FLOAT_TYPES, FLOAT_TYPES);
    }
    else {      // means tads using

        auto packX = ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
        auto packY = ConstantTadHelper::getInstance()->tadForDimensions(gradO.shapeInfo(), dimensions);
        auto packZ = ConstantTadHelper::getInstance()->tadForDimensions(gradI.shapeInfo(), dimensions);

        const int blocksPerGrid = packX.numberOfTads();
        BUILD_DOUBLE_SELECTOR(xType, zType, clipByNormBPCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, context->getCudaStream(), input.specialBuffer(), packX.platformShapeInfo(), packX.platformOffsets(), gradO.specialBuffer(), packY.platformShapeInfo(), packY.platformOffsets(), gradI.specialBuffer(), packZ.platformShapeInfo(), packZ.platformOffsets(), nullptr, clipNormVal), FLOAT_TYPES, FLOAT_TYPES);
    }

    NDArray::registerSpecialUse({&gradI}, {&input, &gradO});

    manager.synchronize();
}
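For orientation, a minimal usage sketch of the helper above (the shapes, values and the LaunchContext call are illustrative assumptions, not part of this diff): an empty dimensions vector selects the whole-array branch, while passing an axis selects the TAD branch with one norm per sub-array.

    // illustrative only
    NDArray input('c', {2, 3}, sd::DataType::FLOAT32);
    NDArray gradO('c', {2, 3}, sd::DataType::FLOAT32);
    NDArray gradI('c', {2, 3}, sd::DataType::FLOAT32);
    NDArray clipNorm = NDArrayFactory::create<double>(1.0);

    // whole array: the launcher receives xTadOffsets == nullptr
    sd::ops::helpers::clipByNormBP(sd::LaunchContext::defaultContext(), input, gradO, gradI, {}, clipNorm);

    // per-row clipping: TAD branch, norms computed along dimension 1
    sd::ops::helpers::clipByNormBP(sd::LaunchContext::defaultContext(), input, gradO, gradI, {1}, clipNorm);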

template <typename T>
static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) {
    auto tid = blockIdx.x * blockDim.x;

@ -692,252 +429,6 @@ void clipByNormBP(sd::LaunchContext* context, const NDArray& input, const NDArra
    output.setIdentity();
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void clipByNormInplaceKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, Nd4jLong const* inputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, T clipNorm) {
    for (int arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) {
        __shared__ T* z;
        __shared__ Nd4jLong len;
        if (threadIdx.x == 0) {
            len = shape::length(shape);
            z = inputBuffer + inputOffsets[arr];
        }
        __syncthreads();
        for (int j = threadIdx.x; j < len; j += blockDim.x) {
            auto xIndex = shape::getIndexOffset(j, shape);

            if(norm2Buf[arr] > clipNorm)
                z[xIndex] *= clipNorm / norm2Buf[arr]; // case with ews = 1 and ordering is 'c'
        }
    }
}
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void clipByNormKernel(Nd4jLong numOfSubArrs, T* inputBuffer, Nd4jLong const* shape, Nd4jLong const* inputOffsets, T* outputBuffer, Nd4jLong const* outputShape, Nd4jLong const* outputOffsets, T* norm2Buf, Nd4jLong const* norm2shape, T clipNorm) {

    for (Nd4jLong arr = blockIdx.x; arr < numOfSubArrs; arr += gridDim.x) {
        __shared__ T *x, *z;
        __shared__ Nd4jLong lenZ;
        __shared__ T norm2;

        if (threadIdx.x == 0) {
            x = inputBuffer + inputOffsets[arr];
            z = outputBuffer + outputOffsets[arr];
            lenZ = shape::length(outputShape);
            norm2 = norm2Buf[shape::getIndexOffset(arr, norm2shape)];
        }
        __syncthreads();
        for (Nd4jLong j = threadIdx.x; j < lenZ; j += blockDim.x) {
            auto xIndex = shape::getIndexOffset(j, shape);
            auto zIndex = shape::getIndexOffset(j, outputShape);
            if(norm2 > clipNorm) {
                z[zIndex] = x[xIndex] * clipNorm / norm2; // case with ews = 1 and ordering is 'c'
            } else {
                z[zIndex] = x[xIndex];
            }
            //printf("%lld: %lf %lf\n", j, z[zIndex], x[xIndex]);
        }
        __syncthreads();
    }
}

//////////////////////////////////////////////////////////////////////////
template<typename T>
static void clipByNorm_(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, NDArray const& clipNormA, const bool isInplace) {
    const int rank = input.rankOf();
    auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions);
    clipNormA.syncToHost();
    //norm2.printBuffer("Norm2");
    T const clipNorm = clipNormA.e<T>(0);
    //clipNormA.printBuffer("ClipNorm");
    auto stream = context->getCudaStream();
    if (isInplace) {
        if(norm2.lengthOf() == 1) {
            norm2.syncToHost();
            T norm2Val = norm2.e<T>(0);
            if(norm2Val > clipNorm)
                input *= clipNorm / norm2Val;
        }
        else {

            std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions);
            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude);
            auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
            //auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimsToExclude);
            T* inputBuffer = reinterpret_cast<T*>(input.specialBuffer());
            T* norm2buf = reinterpret_cast<T*>(norm2.specialBuffer());

            clipByNormInplaceKernel<T><<<256, 512, 1024, *stream>>>(numOfSubArrs, inputBuffer, packX.specialShapeInfo(), packX.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm);
        }
    }
    else {

        if(norm2.lengthOf() == 1) {
            norm2.syncToHost();
            T norm2Val = norm2.e<T>(0);

            if(norm2Val > clipNorm)
                output.assign(input * (clipNorm / norm2Val));
            else
                output.assign(input);
        }
        else {

            std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rank, dimensions);
            const Nd4jLong numOfSubArrs = ShapeUtils::getNumOfSubArrs(input.shapeInfo(), dimsToExclude);
            auto packX = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
            auto packZ = sd::ConstantTadHelper::getInstance()->tadForDimensions(output.shapeInfo(), dimensions);
            T* inputBuffer = reinterpret_cast<T*>(input.specialBuffer());
            T* norm2buf = reinterpret_cast<T*>(norm2.specialBuffer());
            T* outputBuffer = reinterpret_cast<T*>(output.specialBuffer());

            clipByNormKernel<T><<<256, 512, 1024, *stream>>>(numOfSubArrs, inputBuffer, packX.specialShapeInfo(), packX.specialOffsets(), outputBuffer, packZ.specialShapeInfo(), packZ.specialOffsets(), norm2buf, norm2.specialShapeInfo(), clipNorm);
        }
    }
}
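The per-sub-array rule the two kernels above implement is simple: leave a sub-array untouched while its L2 norm is within the threshold, otherwise rescale it onto the clipNorm sphere. A host-side reference in plain C++ (illustrative sketch only, not library code):

    #include <cmath>
    #include <vector>

    // z[i] = norm2 > clipNorm ? x[i] * clipNorm / norm2 : x[i],  with norm2 = sqrt(sum_i x[i]^2)
    std::vector<float> clipByNormReference(const std::vector<float>& x, float clipNorm) {
        float sumSq = 0.f;
        for (float v : x) sumSq += v * v;
        const float norm2 = std::sqrt(sumSq);
        std::vector<float> z(x);
        if (norm2 > clipNorm)
            for (float& v : z) v *= clipNorm / norm2;
        return z;
    }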

void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
    BUILD_SINGLE_SELECTOR(output.dataType(), clipByNorm_, (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByNorm_, (sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);

template <typename T>
void clipByGlobalNorm_(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
    NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); // sqrt(sum([l2norm(t)**2 for t in t_list]))

    for (auto i = 0; i < inputs.size(); i++) {
        auto input = inputs[i];
        auto l2norm = input->reduceNumber(reduce::Norm2);
        globalNorm += l2norm * l2norm;
    }

    globalNorm.applyTransform(transform::Sqrt, globalNorm); // = sd::math::nd4j_sqrt(globalNorm);
    outputs[inputs.size()]->p(0, globalNorm);
    globalNorm.syncToHost();
    const T factor = static_cast<T>(clipNorm) / globalNorm.e<T>(0);

    for (size_t e = 0; e < inputs.size(); e++) {
        // all-reduce
        auto input = inputs[e];
        auto output = outputs[e];

        if (globalNorm.e<double>(0) <= clipNorm) {
            output->assign(input);
        }
        else {

            auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
            input->applyLambda(lambda, *output);
        }
    }
}
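The loop above realises the usual clip-by-global-norm rule: all tensors share one scaling factor derived from their combined norm, so relative magnitudes are preserved. A compact host-side restatement (illustrative sketch, not library code):

    #include <cmath>
    #include <vector>

    // globalNorm = sqrt(sum_t ||t||_2^2); when it exceeds clipNorm, every tensor is scaled by clipNorm / globalNorm
    void clipByGlobalNormReference(std::vector<std::vector<float>>& tensors, float clipNorm) {
        float sumSq = 0.f;
        for (const auto& t : tensors)
            for (float v : t) sumSq += v * v;
        const float globalNorm = std::sqrt(sumSq);
        if (globalNorm <= clipNorm) return;          // already within the bound
        const float factor = clipNorm / globalNorm;
        for (auto& t : tensors)
            for (float& v : t) v *= factor;          // uniform rescale
    }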

void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
    BUILD_SINGLE_SELECTOR(outputs[0]->dataType(), clipByGlobalNorm_, (context, inputs, clipNorm, workspace, outputs, isInplace), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByGlobalNorm_, (sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace), FLOAT_TYPES);

//////////////////////////////////////////////////////////////////////////
template<typename T>
static void clipByAveraged_(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
    auto cn = clipNorm.e<T>(0);
    if (dimensions.size() == 0) {
        // all-reduce
        T n2 = input.reduceNumber(reduce::Norm2).e<T>(0) / static_cast<T>(input.lengthOf());
        if (n2 <= cn) {
            if (!isInplace)
                output.assign(input);
        }
        else {
            const T factor = cn / n2;
            //auto lambda = LAMBDA_T(_x, factor) { return _x * factor; };
            //input.applyLambda<T>(lambda, output);
            output.assign(input * factor);
        }
    }
    else {
        // along dimension
        auto norm2 = input.reduceAlongDimension(reduce::Norm2, dimensions, false);
        if (!isInplace)
            output.assign(input);
        auto tads = output.allTensorsAlongDimension(dimensions);
        auto outTads = output.allTensorsAlongDimension(dimensions);
        // TODO: make this CUDA-compliant somehow
        for (int e = 0; e < tads.size(); e++) {
            T n2 = norm2.e<T>(e) / static_cast<T>(tads.at(e)->lengthOf());
            const T factor = cn / n2;
            if (n2 > cn) {
                //auto lambda = LAMBDA_T(_x, factor) {return _x * factor;};
                tads.at(e)->applyScalar(scalar::Multiply, factor, *outTads.at(e)); //applyLambda<T>(lambda, &output);
            }
        }
    }
}
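The "averaged" variant compares the L2 norm divided by the element count against the threshold, rather than the raw norm. A scalar restatement for one tensor (illustrative sketch only):

    #include <cmath>
    #include <vector>

    // avgNorm = ||x||_2 / n; rescale by clipNorm / avgNorm only when avgNorm exceeds the threshold
    void clipByAveragedNormReference(std::vector<float>& x, float clipNorm) {
        float sumSq = 0.f;
        for (float v : x) sumSq += v * v;
        const float avgNorm = std::sqrt(sumSq) / static_cast<float>(x.size());
        if (avgNorm > clipNorm)
            for (float& v : x) v *= clipNorm / avgNorm;
    }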

void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace) {
    BUILD_SINGLE_SELECTOR(input.dataType(), clipByAveraged_, (context, input, output, dimensions, clipNorm, isInplace), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByAveraged_, (sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);

/*
    if (d1 > params[1])
        return params[1];
    else if (d1 < params[0])
        return params[0];
    else return d1;
*/
template <typename T>
static __global__ void clipByValueKernel(void* input, Nd4jLong const* inputShape, void* output, Nd4jLong const* outputShape, double leftBound, double rightBound) {
    __shared__ T* outputBuf;
    __shared__ T* inputBuf;
    __shared__ Nd4jLong length;
    __shared__ bool linearBuffers;
    if (threadIdx.x == 0) {
        outputBuf = reinterpret_cast<T*>(output);
        inputBuf = reinterpret_cast<T*>(input);
        length = shape::length(inputShape);
        linearBuffers = shape::elementWiseStride(inputShape) == shape::elementWiseStride(outputShape) && shape::elementWiseStride(inputShape) == 1;
    }
    __syncthreads();
    const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
    const auto step = gridDim.x * blockDim.x;

    for (Nd4jLong e = tid; e < length; e += step) {
        if (linearBuffers) {
            if (inputBuf[e] > rightBound) outputBuf[e] = (T) rightBound;
            else if (inputBuf[e] < leftBound) outputBuf[e] = (T) leftBound;
            else outputBuf[e] = inputBuf[e];
        }
        else {
            auto inputOffset = shape::getIndexOffset(e, inputShape);
            auto outputOffset = shape::getIndexOffset(e, outputShape);
            if (inputBuf[inputOffset] > rightBound) outputBuf[outputOffset] = (T) rightBound;
            else if (inputBuf[inputOffset] < leftBound) outputBuf[outputOffset] = (T) leftBound;
            else outputBuf[outputOffset] = inputBuf[inputOffset];   // read via inputOffset, not outputOffset
        }
    }
}

template <typename T>
static void clipByValue_(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
    auto stream = context->getCudaStream();
    if (!input.isActualOnDeviceSide())
        input.syncToDevice();
    NDArray::prepareSpecialUse({&output}, {&input});
    clipByValueKernel<T><<<256, 512, 8192, *stream>>>(input.specialBuffer(), input.specialShapeInfo(), output.specialBuffer(), output.specialShapeInfo(), leftBound, rightBound);
    NDArray::registerSpecialUse({&output}, {&input});
}

void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output) {
    BUILD_SINGLE_SELECTOR(input.dataType(), clipByValue_, (context, input, leftBound, rightBound, output), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void clipByValue_, (sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output), FLOAT_TYPES);

}
}
}
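For reference, the clamp rule the kernel applies element-wise is just min/max against the two bounds; a one-line host equivalent (illustrative sketch only):

    #include <algorithm>

    // out = min(max(in, leftBound), rightBound)
    float clipByValueReference(float v, float leftBound, float rightBound) {
        return std::min(std::max(v, leftBound), rightBound);
    }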

@ -29,9 +29,9 @@ namespace helpers {

    void reverseSequence(sd::LaunchContext * context, const NDArray* input, const NDArray* seqLengths, NDArray* output, int seqDim, const int batchDim);

    void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs, bool isBackProp);
    void reverse(sd::LaunchContext * context, const NDArray* input, NDArray* output, const std::vector<int>* intArgs);

}
}
@ -63,13 +63,13 @@ namespace helpers {
    void mergeAdd(sd::LaunchContext * context, const std::vector<const NDArray*>& inArrs, NDArray& output);
    void mergeAddBp(sd::LaunchContext* context, const NDArray& gradient, std::vector<NDArray*>& outArrs);

    void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);
    void clipByNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace, const bool useAverage);

    void clipByGlobalNorm(sd::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, sd::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace);

    void clipByNormBP(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm);
    void clipByNormBp(sd::LaunchContext * context, const NDArray& input, const NDArray& gradO, NDArray& gradI /*output*/, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool useAverage);

    void clipByAveraged(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);
    void clipByValue(sd::LaunchContext * context, NDArray& input, double leftBound, double rightBound, NDArray& output);
    void clipByAveragedNorm(sd::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace);

    void mirrorPad(sd::LaunchContext * context, const NDArray& input, const NDArray& paddings, NDArray& output, const int mode);


@ -1093,7 +1093,7 @@ namespace sd {
            return ND4J_STATUS_OK;

        NDArray *a0 = block.array(0);
        for (int e = 0; e < block.width(); e++) {
        for (int e = 1; e < block.width(); e++) {
            auto aV = block.array(e);
            if (!shape::equalsSoft(a0->shapeInfo(), aV->shapeInfo()))
                return ND4J_STATUS_BAD_DIMENSIONS;

@ -90,13 +90,12 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
    // x
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);
    mkldnnUtils::setBlockStrides(*x, x_user_md);

    mkldnnUtils::setBlockStrides(x, x_user_md);
    // z, output
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, format);

    mkldnnUtils::setBlockStrides(z, z_user_md);
    mkldnnUtils::setBlockStrides(*z, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@ -112,15 +111,10 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
    // provide memory and check whether reorder is required

    // x
    mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
    mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // z
    auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
    const bool zReorder = op_ff_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : z_user_mem;
    if (zReorder)
        dnnl::reorder(z_user_mem, z_mkl_mem).execute(stream, z_user_mem, z_mkl_mem);
    args[DNNL_ARG_DST] = z_mkl_mem;
    auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_ff_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

    // mean
    auto mean_mkl_mem = dnnl::memory(op_ff_prim_desc.mean_desc(), engine, const_cast<void*>(mean->buffer()));

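The hunks above (and the analogous ones in the conv and concat files below) collapse the per-array "wrap user memory, reorder if the primitive prefers another layout" boilerplate into mkldnnUtils::loadDataToMklStream. A generic sketch of that pattern in plain oneDNN terms (the helper name and signature below are assumed for illustration; the real utility lives in mkldnnUtils):

    #include <dnnl.hpp>
    #include <unordered_map>

    // Wrap a user buffer, reorder it into the layout the primitive prefers when the two
    // descriptors differ, and register the primitive-side memory under argKey.
    // The returned user memory lets callers reorder outputs back after execution.
    dnnl::memory provideMemory(void* userBuffer, const dnnl::memory::desc& user_md,
                               const dnnl::memory::desc& prim_md, const dnnl::engine& engine,
                               dnnl::stream& stream, std::unordered_map<int, dnnl::memory>& args,
                               int argKey, bool isInput) {
        dnnl::memory user_mem(user_md, engine, userBuffer);
        dnnl::memory prim_mem = user_mem;
        if (prim_md != user_mem.get_desc()) {
            prim_mem = dnnl::memory(prim_md, engine);
            if (isInput)
                dnnl::reorder(user_mem, prim_mem).execute(stream, user_mem, prim_mem);
        }
        args[argKey] = prim_mem;
        return user_mem;
    }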
@ -141,8 +135,8 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray
    dnnl::batch_normalization_forward(op_ff_prim_desc).execute(stream, args);

    // reorder outputs if necessary
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
    if (op_ff_prim_desc.dst_desc() != z_user_mem.get_desc())
        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

    stream.wait();

@ -151,7 +145,7 @@ static void batchnormMKLDNN(const NDArray* x, const NDArray* mean, const NDArray


//////////////////////////////////////////////////////////////////////////
static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights,
static void batchnormBpMKLDNN(const NDArray* x, const NDArray* mean, const NDArray* variance, const NDArray &dLdO, const NDArray* weights,
                              NDArray* dLdI, NDArray* dLdW, const float epsilon, const bool isNCHW) {

    // unfortunately mkl dnn doesn't support any format (dnnl::memory::format_tag::any) for x

@ -206,20 +200,17 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
    // x
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(dims, type, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(dims, type, format);

    mkldnnUtils::setBlockStrides(x, x_user_md);
    mkldnnUtils::setBlockStrides(*x, x_user_md);

    // dLdO
    dnnl::memory::desc dLdO_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdO_user_md = dnnl::memory::desc(dims, type, format);

    mkldnnUtils::setBlockStrides(&dLdO, dLdO_user_md);
    mkldnnUtils::setBlockStrides(dLdO, dLdO_user_md);

    // dLdI
    dnnl::memory::desc dLdI_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc dLdI_user_md = dnnl::memory::desc(dims, type, format);

    mkldnnUtils::setBlockStrides(dLdI, dLdI_user_md);
    mkldnnUtils::setBlockStrides(*dLdI, dLdI_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

@ -239,10 +230,10 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
    // provide memory and check whether reorder is required

    // x
    mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
    mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // dLdO
    mkldnnUtils::loadDataToMklStream(&dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
    mkldnnUtils::loadDataToMklStream(dLdO, engine, stream, dLdO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);

    // mean
    auto mean_mkl_mem = dnnl::memory(op_bp_prim_desc.mean_desc(), engine, const_cast<void*>(mean->buffer()));

@ -253,10 +244,7 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
    args[DNNL_ARG_VARIANCE] = var_mkl_mem;

    // dLdI
    auto dLdI_user_mem = dnnl::memory(dLdI_user_md, engine, dLdI->buffer());
    const bool dLdIReorder = op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc();
    auto dLdI_mkl_mem = dLdIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : dLdI_user_mem;
    args[DNNL_ARG_DIFF_SRC] = dLdI_mkl_mem;
    auto dLdI_user_mem = mkldnnUtils::loadDataToMklStream(*dLdI, engine, stream, dLdI_user_md, op_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);

    // gamma and beta (and their gradients) if they are present
    if(weights != nullptr) {

@ -272,8 +260,8 @@ static void batchnormBackPropMKLDNN(const NDArray* x, const NDArray* mean, const
    dnnl::batch_normalization_backward(op_bp_prim_desc).execute(stream, args);

    // reorder outputs if necessary
    if (dLdIReorder)
        dnnl::reorder(dLdI_mkl_mem, dLdI_user_mem).execute(stream, dLdI_mkl_mem, dLdI_user_mem);
    if (op_bp_prim_desc.diff_src_desc() != dLdI_user_mem.get_desc())
        dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], dLdI_user_mem);

    stream.wait();


@ -662,9 +650,9 @@ PLATFORM_IMPL(batchnorm_bp, ENGINE_CPU) {
    const bool isNCHW = !(axes[0] == inRank - 1 && inRank > 2);

    if (shape::strideDescendingCAscendingF(dLdO->shapeInfo()))
        batchnormBackPropMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW);
        batchnormBpMKLDNN(input, mean, variance, *dLdO, weights, dLdI, dLdW, epsilon, isNCHW);
    else
        batchnormBackPropMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW);
        batchnormBpMKLDNN(input, mean, variance, dLdO->dup(), weights, dLdI, dLdW, epsilon, isNCHW);

    *dLdM = 0;
    *dLdV = 0;


@ -0,0 +1,186 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
//

#include <ops/declarable/PlatformHelper.h>
#include <ops/declarable/OpRegistrator.h>
#include <system/platform_boilerplate.h>

#include <helpers/MKLDNNStream.h>
#include "mkldnnUtils.h"
#include <numeric>


namespace sd {
namespace ops {
namespace platforms {


//////////////////////////////////////////////////////////////////////////
static void concatMKLDNN(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {

    // data type
    dnnl::memory::data_type type;
    if(output.dataType() == DataType::FLOAT32)
        type = dnnl::memory::data_type::f32;
    else if(output.dataType() == DataType::HALF)
        type = dnnl::memory::data_type::f16;
    else if(output.dataType() == DataType::BFLOAT16)
        type = dnnl::memory::data_type::bf16;
    else if(output.dataType() == DataType::UINT8)
        type = dnnl::memory::data_type::u8;
    else
        type = dnnl::memory::data_type::s8;

    std::vector<dnnl::memory::desc> x_user_md(inArrs.size()), x_mkl_md(inArrs.size());

    // inputs
    for (int i = 0; i < inArrs.size(); ++i) {

        dnnl::memory::dims dims = inArrs[i]->getShapeAsFlatVector();
        x_user_md[i] = x_mkl_md[i] = dnnl::memory::desc(dims, type, mkldnnUtils::getFormat(*inArrs[i]));
        mkldnnUtils::setBlockStrides(*inArrs[i], x_user_md[i]);
    }

    // output
    dnnl::memory::dims dims = output.getShapeAsFlatVector();
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(dims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(dims, type, mkldnnUtils::getFormat(output));
    mkldnnUtils::setBlockStrides(output, z_user_md);

    std::unordered_map<int, dnnl::memory> args;

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());

    dnnl::concat::primitive_desc op_prim_desc(axis, x_mkl_md, engine);

    dnnl::stream stream(engine);

    // inputs
    for (int i = 0; i < inArrs.size(); ++i)
        mkldnnUtils::loadDataToMklStream(*inArrs[i], engine, stream, x_user_md[i], op_prim_desc.src_desc(i), args[DNNL_ARG_MULTIPLE_SRC + i]);

    // outputs
    auto z_user_mem = mkldnnUtils::loadDataToMklStream(output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

    // primitive execution
    dnnl::concat(op_prim_desc).execute(stream, args);

    // reorder output if necessary
    if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

    stream.wait();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_IMPL(concat, ENGINE_CPU) {

    REQUIRE_TRUE(block.width() > 0, 0, "CONCAT MKLDNN op: No input arrays were provided");

    const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0);

    const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width();

    // first of all take into account possible presence of empty arrays
    // also if scalar is present -> copy its value to vector with length=1
    std::vector<const NDArray*> nonEmptyArrs;
    std::vector<int> arrsToDelete;
    int index = 0;
    bool allOfSameType = true;
    auto rankOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->rankOf() : 0;
    auto typeOfFirstArr = block.width() > 0 ? INPUT_VARIABLE(0)->dataType() : block.dataType();

    for(int i = 0; i < numOfInArrs; ++i) {
        auto input = INPUT_VARIABLE(i);
        auto currentRank = input->rankOf();

        if(!input->isEmpty()) {

            allOfSameType &= (typeOfFirstArr == input->dataType());

            if(input->rankOf() == 0) {
                auto vec = new NDArray('c', {1}, input->dataType(), block.launchContext());
                vec->assign(input);
                nonEmptyArrs.push_back(vec);
                arrsToDelete.push_back(index);
            }
            else {
                nonEmptyArrs.push_back(input);
            }
            ++index;
        }
    }

    const int numOfNonEmptyArrs = nonEmptyArrs.size();

    if(numOfNonEmptyArrs == 0) {
        // all inputs are empty arrays -> return empty, mainly for TF import compatibility (no op)
        REQUIRE_TRUE(OUTPUT_VARIABLE(0)->isEmpty(), 0, "CONCAT MKLDNN op: If all input variables are empty, output must be empty");
        return Status::OK();
    }

    const int rank = nonEmptyArrs[0]->rankOf();     // look up to first non-empty array
    int axis = isAxisInLastArr ? INPUT_VARIABLE(block.width() - 1)->e<int>(0) : INT_ARG(0);
    if(axis < 0) {
        axis += rank;
    }

    // ******** input validation ******** //
    REQUIRE_TRUE(allOfSameType, 0, "CONCAT MKLDNN op: all of input arrays must have same type !");
    REQUIRE_TRUE(nonEmptyArrs[0]->dataType() == OUTPUT_VARIABLE(0)->dataType(), 0, "CONCAT MKLDNN op: output array should have the same type as inputs arrays !");
    REQUIRE_TRUE(0 <= axis && (axis < rank || (axis == 0 && rank == 0)), 0, "CONCAT MKLDNN op: input axis must be in range [0, %i], but got %i instead!", rank-1, axis);

    for(int i = 1; i < numOfNonEmptyArrs; ++i)
        REQUIRE_TRUE(nonEmptyArrs[i]->rankOf() == rank, 0, "CONCAT MKLDNN op: all input arrays must have the same rank !");

    for(int i = 1; i < numOfNonEmptyArrs; ++i) {
        for(int dim = 0; dim < rank; ++dim)
            if(dim != axis)
                REQUIRE_TRUE(nonEmptyArrs[i]->sizeAt(dim) == nonEmptyArrs[0]->sizeAt(dim), 0, "CONCAT MKLDNN op: all input arrays must have the same dimensions (except those on input axis) !");
    }
    // ******** end of input validation ******** //

    auto output = OUTPUT_VARIABLE(0);

    if(numOfNonEmptyArrs == 1)
        output->assign(nonEmptyArrs[0]);
    else
        concatMKLDNN(nonEmptyArrs, *output, axis);

    // delete dynamically allocated vectors with length=1
    for(int index : arrsToDelete)
        delete nonEmptyArrs[index];

    return Status::OK();
}

//////////////////////////////////////////////////////////////////////////
PLATFORM_CHECK(concat, ENGINE_CPU) {

    auto z = OUTPUT_VARIABLE(0);

    const auto zType = z->dataType();

    return z->rankOf() < 7 && (zType == DataType::FLOAT32 || zType == DataType::HALF || zType == DataType::BFLOAT16 || zType == DataType::UINT8 || zType == DataType::INT8);
}

}
}
}

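A quick usage sketch of the op this platform helper backs (illustrative only; the evaluate() call and shapes are assumptions, not taken from this diff). When the CPU engine is built with MKL-DNN and the type/rank check above passes, the call is routed to concatMKLDNN:

    // concatenate two 2x3 arrays along axis 0 -> 4x3 result (illustrative only)
    NDArray x('c', {2, 3}, sd::DataType::FLOAT32);
    NDArray y('c', {2, 3}, sd::DataType::FLOAT32);

    sd::ops::concat op;
    auto result = op.evaluate({&x, &y}, {}, {0});   // last integer argument is the concat axis
    auto z = result.at(0);                          // 4x3 output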
@ -62,33 +62,23 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,

    auto type = dnnl::memory::data_type::f32;

    std::vector<int> permut;
    if(0 == wFormat)
        permut = {3,2,0,1};             // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
    else if(2 == wFormat)
        permut = {0,3,1,2};             // [oC, kH, kW, iC] -> [oC, iC, kH, kW]

    // memory descriptors for arrays

    // input
    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
    mkldnnUtils::setBlockStrides(input, x_user_md);
    mkldnnUtils::setBlockStrides(*input, x_user_md);

    // weights
    dnnl::memory::desc w_mkl_md  = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
    dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
    if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
        w_user_md.data.format_kind = dnnl_blocked;      // overrides format
        uint i0, i1, i2, i3;
        if(0 == wFormat) {
            i0 = 3; i1 = 2; i2 = 0; i3 = 1;             // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
        }
        else if(1 == wFormat) {
            i0 = 0; i1 = 1; i2 = 2; i3 = 3;
        }
        else {
            i0 = 0; i1 = 3; i2 = 1; i3 = 2;             // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
        }
        w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
        w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
        w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
        w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
    }
    mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);

    // bias
    dnnl::memory::desc b_mkl_md;

@ -98,7 +88,7 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
|
|||
// output
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(output, z_user_md);
|
||||
mkldnnUtils::setBlockStrides(*output, z_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -114,10 +104,10 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// bias
|
||||
if(bias != nullptr) {
|
||||
|
@ -126,17 +116,14 @@ static void conv2dMKLDNN(const NDArray *input, const NDArray *weights,
|
|||
}
|
||||
|
||||
// output
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
|
||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
|
||||
|
||||
// run calculations
|
||||
dnnl::convolution_forward(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||
|
@ -170,64 +157,38 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
|||
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
|
||||
std::vector<int> permut;
|
||||
if(0 == wFormat)
|
||||
permut = {3,2,0,1}; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
|
||||
else if(2 == wFormat)
|
||||
permut = {0,3,1,2}; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
|
||||
if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
uint i0, i1, i2, i3;
|
||||
if(0 == wFormat) {
|
||||
i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
|
||||
}
|
||||
else if(1 == wFormat) {
|
||||
i0 = 0; i1 = 1; i2 = 2; i3 = 3;
|
||||
}
|
||||
else {
|
||||
i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
|
||||
}
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
|
||||
}
|
||||
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
|
||||
|
||||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
|
||||
|
||||
// gradW
|
||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
|
||||
if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) {
|
||||
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
uint i0, i1, i2, i3;
|
||||
if(0 == wFormat) {
|
||||
i0 = 3; i1 = 2; i2 = 0; i3 = 1; // [kH, kW, iC, oC] -> [oC, iC, kH, kW]
|
||||
}
|
||||
else if(1 == wFormat) {
|
||||
i0 = 0; i1 = 1; i2 = 2; i3 = 3;
|
||||
}
|
||||
else {
|
||||
i0 = 0; i1 = 3; i2 = 1; i3 = 2; // [oC, kH, kW, iC] -> [oC, iC, kH, kW]
|
||||
}
|
||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
|
||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
|
||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
|
||||
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
|
||||
}
|
||||
mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);
|
||||
|
||||
// gradB
|
||||
dnnl::memory::desc gradB_mkl_md;
|
||||
|
@ -256,10 +217,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// gradO
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
|
||||
|
@ -274,16 +235,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
|||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
|
||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||
|
||||
// gradW
|
||||
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
|
||||
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
|
||||
|
||||
// gradB
|
||||
if(gradB != nullptr) {
|
||||
|
@ -301,10 +256,10 @@ static void conv2dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
|||
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder gradI if necessary
|
||||
if (gradIReorder)
|
||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
if (gradWReorder)
|
||||
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
|
||||
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
|
|
@ -63,6 +63,12 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
|
|||
dnnl::memory::dims wDims = {oC, iC, kD, kH, kW};
|
||||
dnnl::memory::dims zDims = {bS, oC, oD, oH, oW};
|
||||
|
||||
std::vector<int> permut;
|
||||
if(0 == wFormat)
|
||||
permut = {4,3,0,1,2}; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
|
||||
else if(2 == wFormat)
|
||||
permut = {0,4,1,2,3}; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
|
||||
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
@ -70,29 +76,12 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
|
||||
if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
uint i0, i1, i2, i3, i4;
|
||||
if(0 == wFormat) {
|
||||
i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
else if(1 == wFormat) {
|
||||
i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
|
||||
}
|
||||
else {
|
||||
i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
|
||||
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
|
||||
}
|
||||
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
|
||||
|
||||
// bias
|
||||
dnnl::memory::desc b_mkl_md;
|
||||
|
@ -102,7 +91,7 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
|
|||
// output
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(output, z_user_md);
|
||||
mkldnnUtils::setBlockStrides(*output, z_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -118,10 +107,10 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// bias
|
||||
if(bias != nullptr) {
|
||||
|
@ -130,17 +119,14 @@ static void conv3dMKLDNN(const NDArray *input, const NDArray *weights,
|
|||
}
|
||||
|
||||
// output
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
|
||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
|
||||
|
||||
// run calculations
|
||||
dnnl::convolution_forward(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
|
@ -177,68 +163,40 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
|||
|
||||
auto type = dnnl::memory::data_type::f32;
|
||||
|
||||
std::vector<int> permut;
|
||||
if(0 == wFormat)
|
||||
permut = {4,3,0,1,2}; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
|
||||
else if(2 == wFormat)
|
||||
permut = {0,4,1,2,3}; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
|
||||
if(weights->ews() != 1 || weights->ordering() != 'c' || 1 != wFormat) {
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
uint i0, i1, i2, i3, i4;
|
||||
if(0 == wFormat) {
|
||||
i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
else if(1 == wFormat) {
|
||||
i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
|
||||
}
|
||||
else {
|
||||
i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
|
||||
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
|
||||
}
|
||||
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
|
||||
|
||||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFormatMkl);
|
||||
|
||||
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFormatMkl);
|
||||
|
||||
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
|
||||
|
||||
// gradW
|
||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, type, wFormatMkl);
|
||||
if(gradW->ews() != 1 || gradW->ordering() != 'c' || 1 != wFormat) {
|
||||
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
uint i0, i1, i2, i3, i4;
|
||||
if(0 == wFormat) {
|
||||
i0 = 4; i1 = 3; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, iC, oC] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
else if(1 == wFormat) {
|
||||
i0 = 0; i1 = 1; i2 = 2; i3 = 3; i4 = 4;
|
||||
}
|
||||
else {
|
||||
i0 = 0; i1 = 4; i2 = 1; i3 = 2; i4 = 3; // [oC, kD, kH, kW, iC] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
|
||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
|
||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
|
||||
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
|
||||
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4);
|
||||
}
|
||||
mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);
|
||||
|
||||
// gradB
|
||||
dnnl::memory::desc gradB_mkl_md;
|
||||
|
@ -267,10 +225,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// gradO
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
|
||||
|
@ -285,16 +243,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
|||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
|
||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||
|
||||
// gradW
|
||||
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
|
||||
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
|
||||
|
||||
// gradB
|
||||
if(gradB != nullptr) {
|
||||
|
@ -312,10 +264,10 @@ static void conv3dBpMKLDNN(const NDArray *input, const NDArray *weights, const N
|
|||
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder gradI if necessary
|
||||
if (gradIReorder)
|
||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
if (gradWReorder)
|
||||
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
|
||||
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
|
|
@ -47,16 +47,13 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||
dnnl::memory::dims dilation = { dH-1, dW-1 };
|
||||
|
||||
uint i0, i1, i2, i3;
|
||||
if(0 == wFormat) {
|
||||
i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||
}
|
||||
else if(1 == wFormat) {
|
||||
i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
|
||||
}
|
||||
else {
|
||||
i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
|
||||
}
|
||||
std::vector<int> permut;
|
||||
if(0 == wFormat)
|
||||
permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||
else if(1 == wFormat)
|
||||
permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
|
||||
else
|
||||
permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
|
||||
|
||||
// input type
|
||||
dnnl::memory::data_type xType;
|
||||
|
@ -99,16 +96,12 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
|
||||
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
|
||||
|
||||
// bias
|
||||
dnnl::memory::desc b_mkl_md;
|
||||
|
@ -118,7 +111,7 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
// output
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(output, z_user_md);
|
||||
mkldnnUtils::setBlockStrides(*output, z_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -135,10 +128,10 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// bias
|
||||
if(bias != nullptr) {
|
||||
|
@ -147,17 +140,14 @@ static void deconv2dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
}
|
||||
|
||||
// output
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
|
||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
|
||||
|
||||
// run calculations
|
||||
dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
@ -180,16 +170,13 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
|||
dnnl::memory::dims padding_r = { (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||
dnnl::memory::dims dilation = { dH-1, dW-1 };
|
||||
|
||||
uint i0, i1, i2, i3;
|
||||
if(0 == wFormat) {
|
||||
i0 = 2; i1 = 3; i2 = 0; i3 = 1; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||
}
|
||||
else if(1 == wFormat) {
|
||||
i0 = 1; i1 = 0; i2 = 2; i3 = 3; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
|
||||
}
|
||||
else {
|
||||
i0 = 3; i1 = 0; i2 = 1; i3 = 2; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
|
||||
}
|
||||
std::vector<int> permut;
|
||||
if(0 == wFormat)
|
||||
permut = {2,3,0,1}; // [kH, kW, oC, iC] -> [oC, iC, kH, kW]
|
||||
else if(1 == wFormat)
|
||||
permut = {1,0,2,3}; // [iC, oC, kH, kW] -> [oC, iC, kH, kW]
|
||||
else
|
||||
permut = {3,0,1,2}; // [iC, kH, kW, oC] -> [oC, iC, kH, kW]
|
||||
|
||||
// input type
|
||||
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
|
@ -216,35 +203,27 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
|
||||
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
|
||||
|
||||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
|
||||
|
||||
// gradW
|
||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
|
||||
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
|
||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
|
||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
|
||||
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
|
||||
mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);
|
||||
|
||||
// gradB
|
||||
dnnl::memory::desc gradB_mkl_md;
|
||||
|
@ -273,10 +252,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// gradO
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
|
||||
|
@ -291,16 +270,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
|||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
|
||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||
|
||||
// gradW
|
||||
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
|
||||
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
|
||||
|
||||
// gradB
|
||||
if(gradB != nullptr) {
|
||||
|
@ -318,10 +291,10 @@ static void deconv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const
|
|||
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder gradI if necessary
|
||||
if (gradIReorder)
|
||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
if (gradWReorder)
|
||||
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
|
||||
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
|
|
@ -31,7 +31,7 @@ namespace ops {
|
|||
namespace platforms {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
|
||||
static void deconv2TFdBpMKLDNN(const NDArray* weights, const NDArray* gradO, NDArray* gradI,
|
||||
const int bS, const int iC, const int iH, const int iW, const int oC, const int oH, const int oW,
|
||||
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||
const bool isNCHW, const int wFormat) {
|
||||
|
@ -67,21 +67,17 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
|
|||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(3); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(2);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(0);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(1);
|
||||
mkldnnUtils::setBlockStrides(*weights, w_user_md, {3,2,0,1}); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
|
||||
|
||||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -101,23 +97,20 @@ static void deconv2TFdBackPropMKLDNN(const NDArray* weights, const NDArray* grad
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// gradO
|
||||
mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
|
||||
mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, op_data_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
|
||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||
|
||||
// run backward data calculations
|
||||
dnnl::convolution_backward_data(op_data_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder gradI if necessary
|
||||
if (gradIReorder)
|
||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
@ -189,7 +182,7 @@ PLATFORM_IMPL(deconv2d_tf, ENGINE_CPU) {
|
|||
// gradO = new NDArray(gradO->permute({0,3,1,2})); // [bS, oH, oW, oC] -> [bS, oC, oH, oW]
|
||||
// }
|
||||
|
||||
deconv2TFdBackPropMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat);
|
||||
deconv2TFdBpMKLDNN(weights, gradO, gradI, bS, iC, iH, iW, oC, oH, oW, kH, kW, sH, sW, pH, pW, dH, dW, isNCHW, wFormat);
|
||||
|
||||
// delete weights;
|
||||
|
||||
|
|
|
@ -48,16 +48,13 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
|
||||
|
||||
uint i0, i1, i2, i3, i4;
|
||||
if(0 == wFormat) {
|
||||
i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
else if(1 == wFormat) {
|
||||
i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
else {
|
||||
i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
std::vector<int> permut;
|
||||
if(0 == wFormat)
|
||||
permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||
else if(1 == wFormat)
|
||||
permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
|
||||
else
|
||||
permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
|
||||
|
||||
// input type
|
||||
dnnl::memory::data_type xType;
|
||||
|
@ -100,17 +97,12 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
|
||||
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
|
||||
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
|
||||
|
||||
// bias
|
||||
dnnl::memory::desc b_mkl_md;
|
||||
|
@ -120,7 +112,7 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
// output
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(output, z_user_md);
|
||||
mkldnnUtils::setBlockStrides(*output, z_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -137,10 +129,10 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// bias
|
||||
if(bias != nullptr) {
|
||||
|
@ -149,17 +141,14 @@ static void deconv3dMKLDNN(const NDArray* input, const NDArray* weights, const N
|
|||
}
|
||||
|
||||
// output
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
|
||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
|
||||
|
||||
// run calculations
|
||||
dnnl::deconvolution_forward(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
@ -185,16 +174,13 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
dnnl::memory::dims padding_r = { (iD - 1) * sD - oD + kD - pD, (iH - 1) * sH - oH + kH - pH, (iW - 1) * sW - oW + kW - pW };
|
||||
dnnl::memory::dims dilation = { dD-1, dH-1, dW-1 };
|
||||
|
||||
uint i0, i1, i2, i3, i4;
|
||||
if(0 == wFormat) {
|
||||
i0 = 3; i1 = 4; i2 = 0; i3 = 1; i4 = 2; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
else if(1 == wFormat) {
|
||||
i0 = 1; i1 = 0; i2 = 2; i3 = 3; i4 = 4; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
else {
|
||||
i0 = 4; i1 = 0; i2 = 1; i3 = 2; i4 = 3; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
|
||||
}
|
||||
std::vector<int> permut;
|
||||
if(0 == wFormat)
|
||||
permut = {3,4,0,1,2}; // [kD, kH, kW, oC, iC] -> [oC, iC, kD, kH, kW]
|
||||
else if(1 == wFormat)
|
||||
permut = {1,0,2,3,4}; // [iC, oC, kD, kH, kW] -> [oC, iC, kD, kH, kW]
|
||||
else
|
||||
permut = {4,0,1,2,3}; // [iC, kD, kH, kW, oC] -> [oC, iC, kD, kH, kW]
|
||||
|
||||
// input type
|
||||
dnnl::memory::data_type xType = input->dataType() == DataType::FLOAT32 ? dnnl::memory::data_type::f32 : dnnl::memory::data_type::bf16;
|
||||
|
@ -221,37 +207,27 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc w_user_md = dnnl::memory::desc(wDims, wType, wFormatMkl);
|
||||
w_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
w_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(i0);
|
||||
w_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(i1);
|
||||
w_user_md.data.format_desc.blocking.strides[2] = weights->strideAt(i2);
|
||||
w_user_md.data.format_desc.blocking.strides[3] = weights->strideAt(i3);
|
||||
w_user_md.data.format_desc.blocking.strides[4] = weights->strideAt(i4);
|
||||
mkldnnUtils::setBlockStrides(*weights, w_user_md, permut);
|
||||
|
||||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
|
||||
|
||||
// gradW
|
||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradW_user_md = dnnl::memory::desc(wDims, gradWType, wFormatMkl);
|
||||
gradW_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradW_user_md.data.format_desc.blocking.strides[0] = gradW->strideAt(i0);
|
||||
gradW_user_md.data.format_desc.blocking.strides[1] = gradW->strideAt(i1);
|
||||
gradW_user_md.data.format_desc.blocking.strides[2] = gradW->strideAt(i2);
|
||||
gradW_user_md.data.format_desc.blocking.strides[3] = gradW->strideAt(i3);
|
||||
gradW_user_md.data.format_desc.blocking.strides[4] = gradW->strideAt(i4);
|
||||
mkldnnUtils::setBlockStrides(*gradW, gradW_user_md, permut);
|
||||
|
||||
// gradB
|
||||
dnnl::memory::desc gradB_mkl_md;
|
||||
|
@ -281,10 +257,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// gradO
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
|
||||
|
@ -299,16 +275,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
|
||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||
|
||||
// gradW
|
||||
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
|
||||
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
|
||||
|
||||
// gradB
|
||||
if(gradB != nullptr) {
|
||||
|
@ -326,10 +296,10 @@ static void deconv3dBackPropMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
dnnl::deconvolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder gradI if necessary
|
||||
if (gradIReorder)
|
||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
if (gradWReorder)
|
||||
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
|
||||
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@
|
|||
|
||||
using namespace dnnl;
|
||||
|
||||
namespace sd {
|
||||
namespace sd {
|
||||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
|
@ -109,7 +109,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
|
@ -129,7 +129,7 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
// output
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, zType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, zType, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(output, z_user_md);
|
||||
mkldnnUtils::setBlockStrides(*output, z_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -146,10 +146,10 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// bias
|
||||
if(bias != nullptr) {
|
||||
|
@ -158,24 +158,21 @@ static void depthwiseConv2dMKLDNN(const NDArray* input, const NDArray* weights,
|
|||
}
|
||||
|
||||
// output
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
|
||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
|
||||
|
||||
// run calculations
|
||||
dnnl::convolution_forward(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
// shape::printArray(z_mkl_mem.map_data<float>(),8);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||
static void depthwiseConv2dBpMKLDNN(const NDArray* input, const NDArray* weights, const NDArray* gradO, NDArray* gradI, NDArray* gradW, NDArray* gradB,
|
||||
const int kH, const int kW, const int sH, const int sW, const int pH, const int pW, const int dH, const int dW,
|
||||
const int paddingMode, const bool isNCHW, const int wFormat) {
|
||||
|
||||
|
@ -235,7 +232,7 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
|||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, xType, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(input, x_user_md);
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc w_mkl_md = dnnl::memory::desc(wDims, wType, dnnl::memory::format_tag::any);
|
||||
|
@ -250,12 +247,12 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
|||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, gradOType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, gradOType, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(gradO, gradO_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md);
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, gradIType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, gradIType, xzFormatMkl);
|
||||
mkldnnUtils::setBlockStrides(gradI, gradI_user_md);
|
||||
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md);
|
||||
|
||||
// gradW
|
||||
dnnl::memory::desc gradW_mkl_md = dnnl::memory::desc(wDims, gradWType, dnnl::memory::format_tag::any);
|
||||
|
@ -294,10 +291,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_weights_bp_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, w_user_md, op_data_bp_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// gradO
|
||||
auto gradO_user_mem = dnnl::memory(gradO_user_md, engine, const_cast<void*>(gradO->buffer()));
|
||||
|
@ -312,16 +309,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
|||
args[DNNL_ARG_DIFF_DST] = gradO_mkl_memD;
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
|
||||
const bool gradIReorder = op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_data_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_data_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||
|
||||
// gradW
|
||||
auto gradW_user_mem = dnnl::memory(gradW_user_md, engine, gradW->buffer());
|
||||
const bool gradWReorder = op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc();
|
||||
auto gradW_mkl_mem = gradWReorder ? dnnl::memory(op_weights_bp_prim_desc.diff_weights_desc(), engine) : gradW_user_mem;
|
||||
args[DNNL_ARG_DIFF_WEIGHTS] = gradW_mkl_mem;
|
||||
auto gradW_user_mem = mkldnnUtils::loadDataToMklStream(*gradW, engine, stream, gradW_user_md, op_weights_bp_prim_desc.diff_weights_desc(), args[DNNL_ARG_DIFF_WEIGHTS]);
|
||||
|
||||
// gradB
|
||||
if(gradB != nullptr) {
|
||||
|
@ -339,10 +330,10 @@ static void depthwiseConv2dNackPropMKLDNN(const NDArray* input, const NDArray* w
|
|||
dnnl::convolution_backward_weights(op_weights_bp_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder gradI if necessary
|
||||
if (gradIReorder)
|
||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
if (gradWReorder)
|
||||
dnnl::reorder(gradW_mkl_mem, gradW_user_mem).execute(stream, gradW_mkl_mem, gradW_user_mem);
|
||||
if (op_data_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
|
||||
if (op_weights_bp_prim_desc.diff_weights_desc() != gradW_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem).execute(stream, args[DNNL_ARG_DIFF_WEIGHTS], gradW_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
@ -458,7 +449,7 @@ PLATFORM_IMPL(depthwise_conv2d_bp, ENGINE_CPU) {
|
|||
if(bias)
|
||||
REQUIRE_TRUE(bias->rankOf() <= 2 && oC == bias->lengthOf(), 0, "CUSTOM DEPTHWISECONV2D_BP MKL OP: wrong shape of array with biases, expected rank, length: <=2, %i, but got %i, %i instead !", oC, bias->rankOf(), bias->lengthOf());
|
||||
|
||||
depthwiseConv2dNackPropMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
|
||||
depthwiseConv2dBpMKLDNN(input, weights, gradO, gradI, gradW, gradB, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW, wFormat);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
@ -169,71 +169,43 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
x_lstm_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::any);
|
||||
// x_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, nIn}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, nIn}, type, dnnl::memory::format_tag::ntc);
|
||||
x_user_md = dnnl::memory::desc({sL, bS, nIn}, xType, dnnl::memory::format_tag::tnc);
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = x->stridesOf()[0];
|
||||
x_user_md.data.format_desc.blocking.strides[1] = x->stridesOf()[1];
|
||||
x_user_md.data.format_desc.blocking.strides[2] = x->stridesOf()[2];
|
||||
mkldnnUtils::setBlockStrides(*x, x_user_md);
|
||||
|
||||
// wx
|
||||
wx_lstm_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::any);
|
||||
wx_user_md = dnnl::memory::desc({1,dirDim,nIn,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
|
||||
wx_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
wx_user_md.data.format_desc.blocking.strides[0] = Wx->stridesOf()[0];
|
||||
wx_user_md.data.format_desc.blocking.strides[1] = Wx->stridesOf()[1];
|
||||
wx_user_md.data.format_desc.blocking.strides[2] = Wx->stridesOf()[2];
|
||||
wx_user_md.data.format_desc.blocking.strides[3] = Wx->stridesOf()[3];
|
||||
wx_user_md.data.format_desc.blocking.strides[4] = Wx->stridesOf()[4];
|
||||
mkldnnUtils::setBlockStrides(*Wx, wx_user_md);
|
||||
|
||||
// wr
|
||||
wr_lstm_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::any);
|
||||
wr_user_md = dnnl::memory::desc({1,dirDim,nOut,4,nOut}, wType, dnnl::memory::format_tag::ldigo);
|
||||
wr_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
wr_user_md.data.format_desc.blocking.strides[0] = Wr->stridesOf()[0];
|
||||
wr_user_md.data.format_desc.blocking.strides[1] = Wr->stridesOf()[1];
|
||||
wr_user_md.data.format_desc.blocking.strides[2] = Wr->stridesOf()[2];
|
||||
wr_user_md.data.format_desc.blocking.strides[3] = Wr->stridesOf()[3];
|
||||
wr_user_md.data.format_desc.blocking.strides[4] = Wr->stridesOf()[4];
|
||||
mkldnnUtils::setBlockStrides(*Wr, wr_user_md);
|
||||
|
||||
// h
|
||||
h_lstm_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::any);
|
||||
// h_user_md = dataFormat == 0 ? dnnl::memory::desc({sL, bS, hDirDim*nOut}, type, dnnl::memory::format_tag::tnc) : dnnl::memory::desc({bS, sL, hDirDim*nOut}, type, dnnl::memory::format_tag::ntc);
|
||||
h_user_md = dnnl::memory::desc({sL, bS, hDirDim*nOut}, hType, dnnl::memory::format_tag::tnc);
|
||||
h_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
h_user_md.data.format_desc.blocking.strides[0] = h->stridesOf()[0];
|
||||
h_user_md.data.format_desc.blocking.strides[1] = h->stridesOf()[1];
|
||||
h_user_md.data.format_desc.blocking.strides[2] = h->stridesOf()[2];
|
||||
mkldnnUtils::setBlockStrides(*h, h_user_md);
|
||||
|
||||
// b
|
||||
if(b) {
|
||||
b_lstm_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::any);
|
||||
b_user_md = dnnl::memory::desc({1,dirDim,4,nOut}, bType, dnnl::memory::format_tag::ldgo);
|
||||
b_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
b_user_md.data.format_desc.blocking.strides[0] = b->stridesOf()[0];
|
||||
b_user_md.data.format_desc.blocking.strides[1] = b->stridesOf()[1];
|
||||
b_user_md.data.format_desc.blocking.strides[2] = b->stridesOf()[2];
|
||||
b_user_md.data.format_desc.blocking.strides[3] = b->stridesOf()[3];
|
||||
mkldnnUtils::setBlockStrides(*b, b_user_md);
|
||||
}
|
||||
|
||||
// hI
|
||||
if(hI) {
|
||||
hI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
|
||||
hI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
|
||||
hI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
hI_user_md.data.format_desc.blocking.strides[0] = hI->stridesOf()[0];
|
||||
hI_user_md.data.format_desc.blocking.strides[1] = hI->stridesOf()[1];
|
||||
hI_user_md.data.format_desc.blocking.strides[2] = hI->stridesOf()[2];
|
||||
hI_user_md.data.format_desc.blocking.strides[3] = hI->stridesOf()[3];
|
||||
mkldnnUtils::setBlockStrides(*hI, hI_user_md);
|
||||
}
|
||||
|
||||
// cI
|
||||
if(cI) {
|
||||
cI_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::any);
|
||||
cI_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, xType, dnnl::memory::format_tag::ldnc);
|
||||
cI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
cI_user_md.data.format_desc.blocking.strides[0] = cI->stridesOf()[0];
|
||||
cI_user_md.data.format_desc.blocking.strides[1] = cI->stridesOf()[1];
|
||||
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[2];
|
||||
cI_user_md.data.format_desc.blocking.strides[2] = cI->stridesOf()[3];
|
||||
mkldnnUtils::setBlockStrides(*cI, cI_user_md);
|
||||
}
|
||||
|
||||
// hL
|
||||
|
@ -241,20 +213,13 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
hL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::any);
|
||||
hL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
||||
hL_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
hL_user_md.data.format_desc.blocking.strides[0] = hL->stridesOf()[0];
|
||||
hL_user_md.data.format_desc.blocking.strides[1] = hL->stridesOf()[1];
|
||||
hL_user_md.data.format_desc.blocking.strides[2] = hL->stridesOf()[2];
|
||||
hL_user_md.data.format_desc.blocking.strides[3] = hL->stridesOf()[3];
|
||||
mkldnnUtils::setBlockStrides(*hL, hL_user_md);
|
||||
}
|
||||
|
||||
if(cL) {
|
||||
cL_lstm_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
||||
cL_user_md = dnnl::memory::desc({1,dirDim,bS,nOut}, hType, dnnl::memory::format_tag::ldnc);
|
||||
cL_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
cL_user_md.data.format_desc.blocking.strides[0] = cL->stridesOf()[0];
|
||||
cL_user_md.data.format_desc.blocking.strides[1] = cL->stridesOf()[1];
|
||||
cL_user_md.data.format_desc.blocking.strides[2] = cL->stridesOf()[2];
|
||||
cL_user_md.data.format_desc.blocking.strides[3] = cL->stridesOf()[3];
|
||||
mkldnnUtils::setBlockStrides(*cL, cL_user_md);
|
||||
}
|
||||
|
||||
// lstm memory description
|
||||
|
@ -272,64 +237,49 @@ static void lstmLayerMKLDNN(const NDArray* x, const NDArray* Wx, const NDArray*
|
|||
|
||||
// provide memory and check whether reorder is required
|
||||
// x
|
||||
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]);
|
||||
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, lstm_prim_desc.src_layer_desc(), args[DNNL_ARG_SRC_LAYER]);
|
||||
|
||||
// wx
|
||||
mkldnnUtils::loadDataToMklStream(Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]);
|
||||
mkldnnUtils::loadDataToMklStream(*Wx, engine, stream, wx_user_md, lstm_prim_desc.weights_layer_desc(), args[DNNL_ARG_WEIGHTS_LAYER]);
|
||||
|
||||
// wr
|
||||
mkldnnUtils::loadDataToMklStream(Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]);
|
||||
mkldnnUtils::loadDataToMklStream(*Wr, engine, stream, wr_user_md, lstm_prim_desc.weights_iter_desc(), args[DNNL_ARG_WEIGHTS_ITER]);
|
||||
|
||||
// h
|
||||
auto h_user_mem = dnnl::memory(h_user_md, engine, h->buffer());
|
||||
const bool hReorder = lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc();
|
||||
auto h_lstm_mem = hReorder ? dnnl::memory(lstm_prim_desc.dst_layer_desc(), engine) : h_user_mem;
|
||||
args[DNNL_ARG_DST_LAYER] = h_lstm_mem;
|
||||
auto h_user_mem = mkldnnUtils::loadDataToMklStream(*h, engine, stream, h_user_md, lstm_prim_desc.dst_layer_desc(), args[DNNL_ARG_DST_LAYER]);
|
||||
|
||||
// b
|
||||
if(b) {
|
||||
mkldnnUtils::loadDataToMklStream(b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]);
|
||||
}
|
||||
if(b)
|
||||
mkldnnUtils::loadDataToMklStream(*b, engine, stream, b_user_md, lstm_prim_desc.bias_desc(), args[DNNL_ARG_BIAS]);
|
||||
|
||||
// hI
|
||||
if(hI) {
|
||||
mkldnnUtils::loadDataToMklStream(hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]);
|
||||
}
|
||||
if(hI)
|
||||
mkldnnUtils::loadDataToMklStream(*hI, engine, stream, hI_user_md, lstm_prim_desc.src_iter_desc(), args[DNNL_ARG_SRC_ITER]);
|
||||
|
||||
// cI
|
||||
if(cI) {
|
||||
mkldnnUtils::loadDataToMklStream(cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]);
|
||||
}
|
||||
if(cI)
|
||||
mkldnnUtils::loadDataToMklStream(*cI, engine, stream, cI_user_md, lstm_prim_desc.src_iter_c_desc(), args[DNNL_ARG_SRC_ITER_C]);
|
||||
|
||||
bool hLReorder(false), cLReorder(false);
|
||||
dnnl::memory hL_user_mem, cL_user_mem, hL_lstm_mem, cL_lstm_mem;
|
||||
|
||||
// hL
|
||||
if(hL) {
|
||||
hL_user_mem = dnnl::memory(hL_user_md, engine, hL->buffer());
|
||||
hLReorder = lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc();
|
||||
hL_lstm_mem = hLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_desc(), engine) : hL_user_mem;
|
||||
args[DNNL_ARG_DST_ITER] = hL_lstm_mem;
|
||||
}
|
||||
if(hL)
|
||||
hL_user_mem = mkldnnUtils::loadDataToMklStream(*hL, engine, stream, hL_user_md, lstm_prim_desc.dst_iter_desc(), args[DNNL_ARG_DST_ITER]);
|
||||
|
||||
// cL
|
||||
if(cL) {
|
||||
cL_user_mem = dnnl::memory(cL_user_md, engine, cL->buffer());
|
||||
cLReorder = lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc();
|
||||
cL_lstm_mem = cLReorder ? dnnl::memory(lstm_prim_desc.dst_iter_c_desc(), engine) : cL_user_mem;
|
||||
args[DNNL_ARG_DST_ITER_C] = cL_lstm_mem;
|
||||
}
|
||||
if(cL)
|
||||
cL_user_mem = mkldnnUtils::loadDataToMklStream(*cL, engine, stream, cL_user_md, lstm_prim_desc.dst_iter_c_desc(), args[DNNL_ARG_DST_ITER_C]);
|
||||
|
||||
// run calculations
|
||||
lstm_forward(lstm_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (hReorder)
|
||||
reorder(h_lstm_mem, h_user_mem).execute(stream, h_lstm_mem, h_user_mem);
|
||||
if(hLReorder)
|
||||
reorder(hL_lstm_mem, hL_user_mem).execute(stream, hL_lstm_mem, hL_user_mem);
|
||||
if(cLReorder)
|
||||
reorder(cL_lstm_mem, cL_user_mem).execute(stream, cL_lstm_mem, cL_user_mem);
|
||||
if (lstm_prim_desc.dst_layer_desc() != h_user_mem.get_desc())
|
||||
reorder(args[DNNL_ARG_DST_LAYER], h_user_mem).execute(stream, args[DNNL_ARG_DST_LAYER], h_user_mem);
|
||||
if(lstm_prim_desc.dst_iter_desc() != hL_user_mem.get_desc())
|
||||
reorder(args[DNNL_ARG_DST_ITER], hL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER], hL_user_mem);
|
||||
if(lstm_prim_desc.dst_iter_c_desc() != cL_user_mem.get_desc())
|
||||
reorder(args[DNNL_ARG_DST_ITER_C], cL_user_mem).execute(stream, args[DNNL_ARG_DST_ITER_C], cL_user_mem);
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
|
@ -377,9 +327,9 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
|
|||
auto cL = retLastC ? OUTPUT_VARIABLE(count++) : nullptr; // cell state at last step
|
||||
|
||||
// evaluate dimensions
|
||||
const Nd4jLong sL = dataFormat == 3 ? x->sizeAt(0) : x->sizeAt(dataFormat);
|
||||
const Nd4jLong bS = dataFormat == 1 || dataFormat == 2 ? x->sizeAt(0) : x->sizeAt(-2);
|
||||
const Nd4jLong nIn = dataFormat == 2 ? x->sizeAt(1) : x->sizeAt(-1);
|
||||
const Nd4jLong sL = x->sizeAt(dataFormat);
|
||||
const Nd4jLong bS = dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0);
|
||||
const Nd4jLong nIn = x->sizeAt(2);
|
||||
const Nd4jLong nOut = Wx->sizeAt(-1) / 4;
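A hedged reading of the simplified shape extraction above (assumption: only dataFormat 0, x = [sL, bS, nIn], or dataFormat 1, x = [bS, sL, nIn], reaches this MKL-DNN path):
// sketch, not part of the patch
const Nd4jLong sL  = x->sizeAt(dataFormat);                         // time axis is 0 for TNC, 1 for NTC
const Nd4jLong bS  = dataFormat == 0 ? x->sizeAt(1) : x->sizeAt(0); // the other leading axis is the batch
const Nd4jLong nIn = x->sizeAt(2);                                  // feature axis is last in both layouts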
|
||||
|
||||
// inputs validations
|
||||
|
@ -435,14 +385,21 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) {
|
|||
|
||||
WxR = new NDArray(Wx->reshape(Wx->ordering(), {1,dirDim,nIn,4,nOut}));
|
||||
WrR = new NDArray(Wr->reshape(Wr->ordering(), {1,dirDim,nOut,4,nOut}));
|
||||
|
||||
if(b)
|
||||
bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut}));
|
||||
bR = new NDArray(b->reshape(b->ordering(), {1,dirDim,4,nOut}));
|
||||
else
|
||||
bR = new NDArray(x->ordering(), {1,dirDim,4,nOut}, x->dataType(), x->getContext()); // already nullified
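The else-branch above matches the commit note about always passing biases to DNNL: when the user supplies no bias, a zero-initialized {1, dirDim, 4, nOut} array is created so DNNL_ARG_BIAS can always be bound. A minimal sketch of the resulting selection:
// sketch, not part of the patch
NDArray* bR = b ? new NDArray(b->reshape(b->ordering(), {1, dirDim, 4, nOut}))                      // user bias, reshaped
                : new NDArray(x->ordering(), {1, dirDim, 4, nOut}, x->dataType(), x->getContext()); // zero bias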
|
||||
|
||||
if(hI)
|
||||
hIR = new NDArray(hI->reshape(hI->ordering(), {1,dirDim,bS,nOut}));
|
||||
|
||||
if(cI)
|
||||
cIR = new NDArray(cI->reshape(cI->ordering(), {1,dirDim,bS,nOut}));
|
||||
|
||||
if(hL)
|
||||
hLR = new NDArray(hL->reshape(hL->ordering(), {1,dirDim,bS,nOut}, false));
|
||||
|
||||
if(cL)
|
||||
cLR = new NDArray(cL->reshape(cL->ordering(), {1,dirDim,bS,nOut}, false));
|
||||
|
||||
|
|
|
@ -31,20 +31,6 @@ namespace sd {
|
|||
namespace ops {
|
||||
namespace platforms {
|
||||
|
||||
dnnl::memory::format_tag get_format_tag(const sd::NDArray &array) {
|
||||
switch (array.rankOf()) {
|
||||
case 1:
|
||||
return dnnl::memory::format_tag::ab;
|
||||
case 2:
|
||||
return array.ordering() == 'c' ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
|
||||
case 3:
|
||||
return array.ordering() == 'c' ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::cba;
|
||||
default:
|
||||
throw std::runtime_error("MKLDNN matmul only supports 2D/3D arrays");
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const bool transX, const bool transY, float alpha = 1.f, float beta = 0.f) {
|
||||
|
||||
|
@ -123,11 +109,16 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
|
|||
else if(z->dataType() == DataType::INT8)
|
||||
zType = dnnl::memory::data_type::s8;
|
||||
|
||||
|
||||
const auto xFormat = xRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*xTR);
|
||||
const auto yFormat = yRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*yTR);
|
||||
const auto zFormat = zRank == 1 ? dnnl::memory::format_tag::ab : mkldnnUtils::getFormat(*zR);
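The three tags above come from the shared mkldnnUtils::getFormat helper introduced later in this diff; the rank-1 branches keep the pre-existing fallback to a 2-D tag (format_tag::ab) rather than format_tag::a. Sketch:
// sketch, not part of the patch
const auto xFormat = xRank == 1 ? dnnl::memory::format_tag::ab  // 1-D operand treated as 2-D, as before
                                : mkldnnUtils::getFormat(*xTR); // rank- and ordering-aware tag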
|
||||
|
||||
// memory descriptors for arrays
|
||||
dnnl::memory::desc x_mkl_md, x_user_md, y_mkl_md, y_user_md, z_mkl_md, z_user_md;
|
||||
|
||||
// x
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR));
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, get_format_tag(*xTR));
|
||||
x_user_md = x_mkl_md = dnnl::memory::desc(xShape, xType, xFormat);
|
||||
if(xTR->ews() != 1) {
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = xRank == 1 ? 1 : xTR->strideAt(0);
|
||||
|
@ -137,8 +128,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
|
|||
}
|
||||
|
||||
// y
|
||||
dnnl::memory::desc y_mkl_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR));
|
||||
dnnl::memory::desc y_user_md = dnnl::memory::desc(yShape, yType, get_format_tag(*yTR));
|
||||
y_user_md = y_mkl_md = dnnl::memory::desc(yShape, yType, yFormat);
|
||||
if(yTR->ews() != 1) {
|
||||
y_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
y_user_md.data.format_desc.blocking.strides[0] = yRank == 1 ? 1 : yTR->strideAt(0);
|
||||
|
@ -148,8 +138,7 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
|
|||
}
|
||||
|
||||
// z
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR));
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, get_format_tag(*zR));
|
||||
z_user_md = z_mkl_md = dnnl::memory::desc(zShape, zType, zFormat);
|
||||
if(zR->ews() != 1) {
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = zRank == 1 ? 1 : zR->strideAt(0);
|
||||
|
@ -181,37 +170,20 @@ static void matmulMKLDNN(const NDArray* x, const NDArray* y, NDArray* z, const b
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
/*
|
||||
auto x_user_mem = dnnl::memory(x_user_md, engine, xTR->buffer());
|
||||
const bool xReorder = op_prim_desc.src_desc() != x_user_mem.get_desc();
|
||||
auto x_mkl_mem = xReorder ? dnnl::memory(op_prim_desc.src_desc(), engine) : x_user_mem;
|
||||
if (xReorder)
|
||||
dnnl::reorder(x_user_mem, x_mkl_mem).execute(stream, x_user_mem, x_mkl_mem);
|
||||
args[DNNL_ARG_SRC] = x_mkl_mem;
|
||||
*/
|
||||
mkldnnUtils::loadDataToMklStream(*xTR, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// y
|
||||
mkldnnUtils::loadDataToMklStream(yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
/*
|
||||
auto y_user_mem = dnnl::memory(y_user_md, engine, yTR->buffer());
|
||||
const bool yReorder = op_prim_desc.weights_desc() != y_user_mem.get_desc();
|
||||
auto y_mkl_mem = yReorder ? dnnl::memory(op_prim_desc.weights_desc(), engine) : y_user_mem;
|
||||
if (yReorder)
|
||||
dnnl::reorder(y_user_mem, y_mkl_mem).execute(stream, y_user_mem, y_mkl_mem);
|
||||
args[DNNL_ARG_WEIGHTS] = y_mkl_mem;
|
||||
*/
|
||||
mkldnnUtils::loadDataToMklStream(*yTR, engine, stream, y_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// z
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, zR->buffer());
|
||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*zR, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
|
||||
|
||||
// run calculations
|
||||
dnnl::matmul(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
|
||||
|
|
|
@ -38,45 +38,65 @@ void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims){
|
|||
mklDims = dnnl::memory::dims(vDims);
|
||||
}
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
dnnl::memory::format_tag getFormat(const int rank){
|
||||
if (2 == rank) {
|
||||
return dnnl::memory::format_tag::ab;
|
||||
}
|
||||
else if (3 == rank) {
|
||||
return dnnl::memory::format_tag::abc;
|
||||
}
|
||||
else if (4 == rank) {
|
||||
return dnnl::memory::format_tag::abcd;
|
||||
}
|
||||
else if (5 == rank) {
|
||||
return dnnl::memory::format_tag::abcde;
|
||||
}
|
||||
else if (6 == rank) {
|
||||
return dnnl::memory::format_tag::abcdef;
|
||||
}
|
||||
return dnnl::memory::format_tag::a; // 1 == dataSetRank
|
||||
dnnl::memory::format_tag getFormat(const NDArray& arr) {
|
||||
|
||||
dnnl::memory::format_tag result;
|
||||
|
||||
switch (arr.rankOf()) {
|
||||
case 1:
|
||||
result = dnnl::memory::format_tag::a;
|
||||
break;
|
||||
case 2:
|
||||
result = arr.ordering() == 'c' ? dnnl::memory::format_tag::ab : dnnl::memory::format_tag::ba;
|
||||
break;
|
||||
case 3:
|
||||
result = arr.ordering() == 'c' ? dnnl::memory::format_tag::abc : dnnl::memory::format_tag::cba;
|
||||
break;
|
||||
case 4:
|
||||
result = dnnl::memory::format_tag::abcd;
|
||||
break;
|
||||
case 5:
|
||||
result = dnnl::memory::format_tag::abcde;
|
||||
break;
|
||||
case 6:
|
||||
result = dnnl::memory::format_tag::abcdef;
|
||||
break;
|
||||
default:
|
||||
throw std::invalid_argument("MKLDNN getFormat: do we really want to use arrays with rank > 6 ?");
|
||||
}
|
||||
|
||||
return result;
|
||||
}
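Usage sketch for the helper above; the values follow directly from the switch, with the 'f'-ordered 2-D/3-D cases mapping to the transposed tags:
// sketch, not part of the patch
// getFormat(NDArray('c', {2,3}))   -> dnnl::memory::format_tag::ab
// getFormat(NDArray('f', {2,3}))   -> dnnl::memory::format_tag::ba
// getFormat(NDArray('f', {2,3,4})) -> dnnl::memory::format_tag::cba
// rank 4..6 -> abcd / abcde / abcdef regardless of ordering; rank > 6 throws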
//////////////////////////////////////////////////////////////////////
|
||||
void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd){
|
||||
void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector<int>& permut) {
|
||||
|
||||
if (array->ews() != 1 || array->ordering() != 'c') {
|
||||
mklMd.data.format_kind = dnnl_blocked; // overrides format
|
||||
for (auto i = 0; i < array->rankOf(); ++i) {
|
||||
mklMd.data.format_desc.blocking.strides[i] = array->strideAt(i);
|
||||
if (array.ews() != 1 || (array.rankOf() > 3 && array.ordering() == 'f') || !permut.empty()) {
|
||||
|
||||
mklMd.data.format_kind = dnnl_blocked; // overrides format
|
||||
|
||||
if(permut.empty())
|
||||
for (auto i = 0; i < array.rankOf(); ++i)
|
||||
mklMd.data.format_desc.blocking.strides[i] = array.strideAt(i);
|
||||
else {
|
||||
if(array.rankOf() != permut.size())
|
||||
throw std::invalid_argument("mkldnnUtils::setBlockStrides: size of permut vector is not equal to array rank !");
|
||||
for (auto i = 0; i < array.rankOf(); ++i)
|
||||
mklMd.data.format_desc.blocking.strides[i] = array.strideAt(permut[i]);
|
||||
}
|
||||
}
|
||||
}
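Usage sketch for the new overload; the permut argument is how the convolution/deconvolution ops above describe permuted weight layouts (assumption: 4-D weights stored as [kH, kW, oC, iC]):
// sketch, not part of the patch
dnnl::memory::desc w_user_md(wDims, wType, wFormatMkl);
mkldnnUtils::setBlockStrides(*weights, w_user_md, {2, 3, 0, 1}); // expose strides as [oC, iC, kH, kW]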
////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
|
||||
dnnl::memory& arg) {
|
||||
dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream,
|
||||
const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md, dnnl::memory& arg) {
|
||||
|
||||
auto user_mem = dnnl::memory(user_md, engine,const_cast<void*>(array->buffer()));
|
||||
auto user_mem = dnnl::memory(user_md, engine, const_cast<NDArray&>(array).buffer());
|
||||
const bool bReorder = primitive_md != user_mem.get_desc();
|
||||
auto mkl_mem = bReorder ? dnnl::memory(primitive_md, engine) : user_mem;
|
||||
if (bReorder)
|
||||
dnnl::reorder(user_mem, mkl_mem).execute(stream, user_mem, mkl_mem);
|
||||
arg = mkl_mem;
|
||||
return user_mem;
|
||||
}
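Caller pattern for the changed helper, as used throughout this diff: for inputs the return value is ignored, while for outputs the returned user memory is kept so the result can be reordered back after the primitive runs. Sketch:
// sketch, not part of the patch
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md,
                                                   op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
// ... execute the primitive ...
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
    dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);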
//////////////////////////////////////////////////////////////////////
|
||||
|
@ -122,33 +142,21 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
|
|||
xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
}
|
||||
|
||||
std::vector<int> permut;
|
||||
if(!isNCHW)
|
||||
permut = rank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
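For the NHWC/NDHWC case the permutation above reproduces the per-index stride assignments that this change removes below (assumption: rank-4 input):
// sketch, not part of the patch: setBlockStrides(*input, x_user_md, {0,3,1,2}) yields
// strides = { input->strideAt(0), input->strideAt(3), input->strideAt(1), input->strideAt(2) }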
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1);
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
|
||||
if(rank == 5)
|
||||
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
|
||||
}
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md, permut);
|
||||
|
||||
// output
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
||||
if(output->ews() != 1 || output->ordering() != 'c') {
|
||||
z_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
z_user_md.data.format_desc.blocking.strides[0] = output->strideAt(0);
|
||||
z_user_md.data.format_desc.blocking.strides[1] = output->strideAt(isNCHW ? 1 :-1);
|
||||
z_user_md.data.format_desc.blocking.strides[2] = output->strideAt(isNCHW ? 2 : 1);
|
||||
z_user_md.data.format_desc.blocking.strides[3] = output->strideAt(isNCHW ? 3 : 2);
|
||||
if(rank == 5)
|
||||
z_user_md.data.format_desc.blocking.strides[4] = output->strideAt(isNCHW ? 4 : 3);
|
||||
}
|
||||
mkldnnUtils::setBlockStrides(*output, z_user_md, permut);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
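A side note on the permut argument introduced above: for the NHWC branch the descriptor dims are listed in NCHW order while the array's strides are laid out NHWC, so the blocked strides have to be read through a permutation. A rough illustration for rank 4 (the stride values shown are the usual row-major ones, not taken from this diff):

    // NHWC input of shape [N,H,W,C] stored in 'c' order has strides {H*W*C, W*C, C, 1}.
    // setBlockStrides(*input, x_user_md, {0,3,1,2}) fills the NCHW-ordered descriptor with
    //   strides(N,C,H,W) = { strideAt(0), strideAt(3), strideAt(1), strideAt(2) } = { H*W*C, 1, W*C, C }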
@ -164,20 +172,17 @@ void poolingMKLDNN(const NDArray *input, NDArray *output,
    // provide memory buffers and check whether reorder is required

    // input
    mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
    mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // output
    auto z_user_mem = dnnl::memory(z_user_md, engine, output->buffer());
    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
    args[DNNL_ARG_DST] = z_mkl_mem;
    auto z_user_mem = mkldnnUtils::loadDataToMklStream(*output, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

    // run calculations
    dnnl::pooling_forward(op_prim_desc).execute(stream, args);

    // reorder outputs if necessary
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
    if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

    stream.wait();
}
@ -226,46 +231,27 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
|
|||
xzFrmat = isNCHW ? dnnl::memory::format_tag::ncdhw : dnnl::memory::format_tag::ndhwc;
|
||||
}
|
||||
|
||||
std::vector<int> permut;
|
||||
if(!isNCHW)
|
||||
permut = rank == 4 ? std::vector<int>({0,3,1,2}) : std::vector<int>({0,4,1,2,3});
|
||||
|
||||
|
||||
// memory descriptors for arrays
|
||||
|
||||
// input
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||
if(input->ews() != 1 || input->ordering() != 'c') {
|
||||
x_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
x_user_md.data.format_desc.blocking.strides[0] = input->strideAt(0);
|
||||
x_user_md.data.format_desc.blocking.strides[1] = input->strideAt(isNCHW ? 1 :-1);
|
||||
x_user_md.data.format_desc.blocking.strides[2] = input->strideAt(isNCHW ? 2 : 1);
|
||||
x_user_md.data.format_desc.blocking.strides[3] = input->strideAt(isNCHW ? 3 : 2);
|
||||
if(rank == 5)
|
||||
x_user_md.data.format_desc.blocking.strides[4] = input->strideAt(isNCHW ? 4 : 3);
|
||||
}
|
||||
mkldnnUtils::setBlockStrides(*input, x_user_md, permut);
|
||||
|
||||
// gradO
|
||||
dnnl::memory::desc gradO_mkl_md = dnnl::memory::desc(zDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradO_user_md = dnnl::memory::desc(zDims, type, xzFrmat);
|
||||
if(gradO->ews() != 1 || gradO->ordering() != 'c') {
|
||||
gradO_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradO_user_md.data.format_desc.blocking.strides[0] = gradO->strideAt(0);
|
||||
gradO_user_md.data.format_desc.blocking.strides[1] = gradO->strideAt(isNCHW ? 1 :-1);
|
||||
gradO_user_md.data.format_desc.blocking.strides[2] = gradO->strideAt(isNCHW ? 2 : 1);
|
||||
gradO_user_md.data.format_desc.blocking.strides[3] = gradO->strideAt(isNCHW ? 3 : 2);
|
||||
if(rank == 5)
|
||||
gradO_user_md.data.format_desc.blocking.strides[4] = gradO->strideAt(isNCHW ? 4 : 3);
|
||||
}
|
||||
mkldnnUtils::setBlockStrides(*gradO, gradO_user_md, permut);
|
||||
|
||||
// gradI
|
||||
dnnl::memory::desc gradI_mkl_md = dnnl::memory::desc(xDims, type, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc gradI_user_md = dnnl::memory::desc(xDims, type, xzFrmat);
|
||||
if(gradI->ews() != 1 || gradI->ordering() != 'c') {
|
||||
gradI_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
gradI_user_md.data.format_desc.blocking.strides[0] = gradI->strideAt(0);
|
||||
gradI_user_md.data.format_desc.blocking.strides[1] = gradI->strideAt(isNCHW ? 1 :-1);
|
||||
gradI_user_md.data.format_desc.blocking.strides[2] = gradI->strideAt(isNCHW ? 2 : 1);
|
||||
gradI_user_md.data.format_desc.blocking.strides[3] = gradI->strideAt(isNCHW ? 3 : 2);
|
||||
if(rank == 5)
|
||||
gradI_user_md.data.format_desc.blocking.strides[4] = gradI->strideAt(isNCHW ? 4 : 3);
|
||||
}
|
||||
mkldnnUtils::setBlockStrides(*gradI, gradI_user_md, permut);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
dnnl::stream stream(engine);
|
||||
|
@ -282,18 +268,15 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
|
|||
std::unordered_map<int, dnnl::memory> args;
|
||||
|
||||
// gradO
|
||||
mkldnnUtils::loadDataToMklStream(gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
|
||||
mkldnnUtils::loadDataToMklStream(*gradO, engine, stream, gradO_user_md, op_bp_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
|
||||
|
||||
// gradI
|
||||
auto gradI_user_mem = dnnl::memory(gradI_user_md, engine, gradI->buffer());
|
||||
const bool gradIReorder = op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc();
|
||||
auto gradI_mkl_mem = gradIReorder ? dnnl::memory(op_bp_prim_desc.diff_src_desc(), engine) : gradI_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = gradI_mkl_mem;
|
||||
auto gradI_user_mem = mkldnnUtils::loadDataToMklStream(*gradI, engine, stream, gradI_user_md, op_bp_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||
|
||||
if(mode == algorithm::pooling_max) {
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*input, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// z
|
||||
auto z_mkl_mem = dnnl::memory(op_ff_prim_desc.dst_desc(), engine);
|
||||
|
@ -310,10 +293,9 @@ void poolingBpMKLDNN(const NDArray *input, const NDArray *gradO, NDArray *gradI,
|
|||
// run backward calculations
|
||||
dnnl::pooling_backward(op_bp_prim_desc).execute(stream, args);
|
||||
|
||||
|
||||
// reorder gradI if necessary
|
||||
if (gradIReorder)
|
||||
dnnl::reorder(gradI_mkl_mem, gradI_user_mem).execute(stream, gradI_mkl_mem, gradI_user_mem);
|
||||
if (op_bp_prim_desc.diff_src_desc() != gradI_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], gradI_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], gradI_user_mem);
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
|
|
|
@ -100,6 +100,8 @@ namespace sd {

        DECLARE_PLATFORM(xw_plus_b_bp, ENGINE_CPU);

        DECLARE_PLATFORM(concat, ENGINE_CPU);

    }
}
@ -123,19 +125,13 @@ namespace sd {
             */
            void getDims(const NDArray* array, const int rank, dnnl::memory::dims& mklDims);
            /**
             * This function generate memory format tag based on rank
             * @param const array rank
             * This function evaluates the memory format tag based on the array's shapeInfo
             * @param const array
             * @return memory format
             */
            dnnl::memory::format_tag getFormat(const int rank);
            /**
             * This function generate memory format tag based on rank
             * @param const pointer to dataset
             * @param const dataset rank
             * @param reference to memory descriptor
             * @return memory format
             */
            void setBlockStrides(const NDArray* array, dnnl::memory::desc& mklMd);
            dnnl::memory::format_tag getFormat(const NDArray& arr);

            void setBlockStrides(const NDArray& array, dnnl::memory::desc& mklMd, const std::vector<int>& permut = {});
            //////////////////////////////////////////////////////////////////////
            /**
             * This function loads and reorders user memory into mkl memory

@ -147,7 +143,7 @@ namespace sd {
             * @param primitive memory descriptor
             * @param dnnl arg activation enumerator
             */
            void loadDataToMklStream(const NDArray* array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
            dnnl::memory loadDataToMklStream(const NDArray& array, const dnnl::engine& engine, const dnnl::stream& stream, const dnnl::memory::desc& user_md, const dnnl::memory::desc& primitive_md,
                                             dnnl::memory& arg);

            /**
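A hypothetical snippet of how the declarations above are meant to compose when building a user memory descriptor for an NDArray (arr is an assumed variable and f32 an example data type; this is a sketch, not code from this change):

    dnnl::memory::desc user_md(arr.getShapeAsFlatVector(),            // dims taken from the array
                               dnnl::memory::data_type::f32,
                               mkldnnUtils::getFormat(arr));           // plain format tag chosen for this array
    mkldnnUtils::setBlockStrides(arr, user_md);                        // switch to blocked strides when ews/order require it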
@ -35,32 +35,37 @@ namespace sd {
//////////////////////////////////////////////////////////////////////
static void softmaxMKLDNN(const NDArray* x, NDArray* z, const int axis) {

    const auto xRank = x->rankOf();
    dnnl::memory::dims xShape, zShape;
    dnnl::memory::dims shape = x->getShapeAsFlatVector();

    mkldnnUtils::getDims(x, xRank, xShape);
    mkldnnUtils::getDims(z, xRank, zShape);
    const int xRank = x->rankOf();

    dnnl::memory::format_tag xFormat = mkldnnUtils::getFormat(*x);
    dnnl::memory::format_tag zFormat = mkldnnUtils::getFormat(*z);

    dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
    // optimized cases
    if (2 == xRank && 0 == axis) {
        format = dnnl::memory::format_tag::ba;
        if(x->ews() == 1)
            xFormat = dnnl::memory::format_tag::ba;
        if(z->ews() == 1)
            zFormat = dnnl::memory::format_tag::ba;
    }
    else if (4 == xRank && 1 == axis && (x->sizeAt(2) * x->sizeAt(3)) > 1) {
        format = dnnl::memory::format_tag::acdb;
        if(x->ews() == 1)
            xFormat = dnnl::memory::format_tag::acdb;
        if(z->ews() == 1)
            zFormat = dnnl::memory::format_tag::acdb;
    }

    dnnl::memory::data_type xType = dnnl::memory::data_type::f32;

    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xShape, xType, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
    mkldnnUtils::setBlockStrides(x, x_user_md);
    dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md;

    x_user_md = x_mkl_md = dnnl::memory::desc(shape, xType, xFormat);
    mkldnnUtils::setBlockStrides(*x, x_user_md);

    // z
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zShape, xType, format);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, xType, format);
    mkldnnUtils::setBlockStrides(z, z_user_md);
    z_user_md = z_mkl_md = dnnl::memory::desc(shape, xType, zFormat);
    mkldnnUtils::setBlockStrides(*z, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -80,20 +85,17 @@ namespace sd {
    // provide memory buffers and check whether reorder is required

    // input
    mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
    mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // z
    auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
    args[DNNL_ARG_DST] = z_mkl_mem;
    auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

    // run calculations
    dnnl::softmax_forward(op_prim_desc).execute(stream, args);

    // reorder outputs if necessary
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
    if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

    stream.wait();
}
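For orientation only: the math the primitive is expected to reproduce is an ordinary softmax along the chosen axis. A plain C++ reference over the last axis of a contiguous [rows, cols] buffer could look like this (an illustrative sketch under that assumption, not part of the library; out must be pre-sized to rows*cols):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    static void softmaxRef(const std::vector<float>& in, std::vector<float>& out, int rows, int cols) {
        for (int r = 0; r < rows; ++r) {
            const float* xi = in.data()  + r * cols;
            float*       zi = out.data() + r * cols;
            float mx = xi[0];                                   // subtract the row max for numerical stability
            for (int c = 1; c < cols; ++c) mx = std::max(mx, xi[c]);
            float sum = 0.f;
            for (int c = 0; c < cols; ++c) { zi[c] = std::exp(xi[c] - mx); sum += zi[c]; }
            for (int c = 0; c < cols; ++c) zi[c] /= sum;        // normalize so each row sums to 1
        }
    }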
@ -142,33 +144,19 @@ namespace sd {
|
|||
//////////////////////////////////////////////////////////////////////
|
||||
static void softmaxBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx, const int axis) {
|
||||
|
||||
const auto xRank = x->rankOf();
|
||||
const auto dLdzRank = dLdz->rankOf();
|
||||
|
||||
dnnl::memory::dims xShape, dLdxShape, dLdzShape;
|
||||
|
||||
mkldnnUtils::getDims(x, xRank, xShape);
|
||||
mkldnnUtils::getDims(dLdx, xRank, dLdxShape);
|
||||
mkldnnUtils::getDims(dLdz, dLdzRank, dLdzShape);
|
||||
|
||||
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
|
||||
dnnl::memory::desc x_user_md, x_mkl_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md;
|
||||
|
||||
// x
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
mkldnnUtils::setBlockStrides(x, x_user_md);
|
||||
x_mkl_md = x_user_md = dnnl::memory::desc(x->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x));
|
||||
mkldnnUtils::setBlockStrides(*x, x_user_md);
|
||||
|
||||
// dLdx
|
||||
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
|
||||
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(dLdxShape, dnnl::memory::data_type::f32, format);
|
||||
mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
|
||||
// todo if mkl does not support broadcast we can remove this
|
||||
format = mkldnnUtils::getFormat(dLdzRank);
|
||||
dLdx_mkl_md = dLdx_user_md = dnnl::memory::desc(dLdx->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx));
|
||||
mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md);
|
||||
|
||||
// dLdz
|
||||
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
|
||||
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dnnl::memory::data_type::f32, format);
|
||||
mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
|
||||
dLdz_mkl_md = dLdz_user_md = dnnl::memory::desc(dLdz->getShapeAsFlatVector(), dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz));
|
||||
mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -188,19 +176,18 @@ namespace sd {
|
|||
|
||||
// provide memory buffers and check whether reorder is required for forward
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_SRC]);
|
||||
|
||||
// dLdz
|
||||
mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]);
|
||||
|
||||
// dLdx
|
||||
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer());
|
||||
const bool dLdxReorder = op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc();
|
||||
auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_ff_prim_desc.dst_desc(), engine) : dLdx_user_mem;
|
||||
argsff[DNNL_ARG_DST] = dLdx_mkl_mem;
|
||||
auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_ff_prim_desc.src_desc(), argsff[DNNL_ARG_DST]);
|
||||
|
||||
// check and arg set for backprob
|
||||
argsbp[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
|
||||
argsbp[DNNL_ARG_DST] = dLdx_mkl_mem;
|
||||
// dLdz
|
||||
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bp_prim_desc.diff_dst_desc(), argsbp[DNNL_ARG_DIFF_DST]);
|
||||
argsbp[DNNL_ARG_DIFF_SRC] = argsff[DNNL_ARG_DST];
|
||||
argsbp[DNNL_ARG_DST] = argsff[DNNL_ARG_DST];
|
||||
|
||||
|
||||
// run calculations forward
|
||||
dnnl::softmax_forward(op_ff_prim_desc).execute(stream, argsff);
|
||||
|
@ -209,8 +196,8 @@ namespace sd {
|
|||
dnnl::softmax_backward(op_bp_prim_desc).execute(stream, argsbp);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (dLdxReorder)
|
||||
dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
|
||||
if (op_ff_prim_desc.dst_desc() != dLdx_user_mem.get_desc())
|
||||
dnnl::reorder(argsff[DNNL_ARG_DST], dLdx_user_mem).execute(stream, argsff[DNNL_ARG_DST], dLdx_user_mem);
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
|
|
|
@ -34,22 +34,16 @@ namespace sd {
//////////////////////////////////////////////////////////////////////
static void tanhMKLDNN(const NDArray* x, NDArray* z) {

    const auto xRank = x->rankOf();
    dnnl::memory::dims xShape, zShape;
    dnnl::memory::dims shape = x->getShapeAsFlatVector();

    mkldnnUtils::getDims(x, xRank, xShape);
    mkldnnUtils::getDims(z, xRank, zShape);
    dnnl::memory::desc x_mkl_md, x_user_md, z_mkl_md, z_user_md;

    dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);

    dnnl::memory::desc x_mkl_md  = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
    mkldnnUtils::setBlockStrides(x, x_user_md);
    x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x));
    mkldnnUtils::setBlockStrides(*x, x_user_md);

    // z
    dnnl::memory::desc z_mkl_md  = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
    dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, dnnl::memory::data_type::f32, format);
    mkldnnUtils::setBlockStrides(z, z_user_md);
    z_user_md = z_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*z));
    mkldnnUtils::setBlockStrides(*z, z_user_md);

    auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
@ -68,20 +62,17 @@ namespace sd {
    // provide memory buffers and check whether reorder is required
    // input
    mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
    mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);

    // z
    auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
    const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
    auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
    args[DNNL_ARG_DST] = z_mkl_mem;
    auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);

    // run calculations
    dnnl::eltwise_forward(op_prim_desc).execute(stream, args);

    // reorder outputs if necessary
    if (zReorder)
        dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
    if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
        dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);

    stream.wait();
}
@ -121,28 +112,21 @@ namespace sd {
|
|||
//////////////////////////////////////////////////////////////////////
|
||||
static void tanhBpMKLDNN(const NDArray* x, const NDArray* dLdz, NDArray* dLdx) {
|
||||
|
||||
const auto xRank = x->rankOf();
|
||||
dnnl::memory::dims xShape, dLdzShape, dLdxShape;
|
||||
dnnl::memory::dims shape = x->getShapeAsFlatVector();
|
||||
|
||||
mkldnnUtils::getDims(x, xRank, xShape);
|
||||
mkldnnUtils::getDims(dLdz, xRank, dLdzShape);
|
||||
mkldnnUtils::getDims(dLdx, xRank, dLdxShape);
|
||||
dnnl::memory::desc x_mkl_md, x_user_md, dLdx_mkl_md, dLdx_user_md, dLdz_mkl_md, dLdz_user_md;
|
||||
|
||||
dnnl::memory::format_tag format = mkldnnUtils::getFormat(xRank);
|
||||
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
mkldnnUtils::setBlockStrides(x, x_user_md);
|
||||
// x
|
||||
x_user_md = x_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*x));
|
||||
mkldnnUtils::setBlockStrides(*x, x_user_md);
|
||||
|
||||
// dLdz
|
||||
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
|
||||
dLdz_user_md = dLdz_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdz));
|
||||
mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md);
|
||||
|
||||
// dLdx
|
||||
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dnnl::memory::data_type::f32, format);
|
||||
mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
|
||||
dLdx_user_md = dLdx_mkl_md = dnnl::memory::desc(shape, dnnl::memory::data_type::f32, mkldnnUtils::getFormat(*dLdx));
|
||||
mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -162,23 +146,20 @@ namespace sd {
|
|||
|
||||
// provide memory buffers and check whether reorder is required for forward
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// dLdz
|
||||
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
|
||||
mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_prim_desc.diff_dst_desc(), args[DNNL_ARG_DIFF_DST]);
|
||||
|
||||
// dLdx
|
||||
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer());
|
||||
const bool dLdxReorder = op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc();
|
||||
auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_prim_desc.diff_src_desc(), engine) : dLdx_user_mem;
|
||||
args[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
|
||||
auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_prim_desc.diff_src_desc(), args[DNNL_ARG_DIFF_SRC]);
|
||||
|
||||
// run calculations backward
|
||||
dnnl::eltwise_backward(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (dLdxReorder)
|
||||
dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
|
||||
if (op_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DIFF_SRC], dLdx_user_mem).execute(stream, args[DNNL_ARG_DIFF_SRC], dLdx_user_mem);
|
||||
|
||||
stream.wait();
|
||||
}
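Since the commit notes a bug caught in mkldnn tanh_bp, it may help to spell out the reference semantics the eltwise backward primitive is expected to match. A sketch assuming the usual NDArray e/p element accessors and matching flat buffers (illustrative only, not the op's implementation):

    // dLdx[i] = dLdz[i] * (1 - tanh(x[i])^2)
    for (Nd4jLong i = 0; i < x->lengthOf(); ++i) {
        const double t = std::tanh(x->e<double>(i));
        dLdx->p(i, dLdz->e<double>(i) * (1.0 - t * t));
    }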
|
||||
|
|
|
@ -82,33 +82,23 @@ namespace sd {
|
|||
// memory descriptors for arrays
|
||||
// x
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, xType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, format);
|
||||
mkldnnUtils::setBlockStrides(x, x_user_md);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, xType, mkldnnUtils::getFormat(*x));
|
||||
mkldnnUtils::setBlockStrides(*x, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, wType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, format);
|
||||
if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) {
|
||||
dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, wType, mkldnnUtils::getFormat(*weights));
|
||||
mkldnnUtils::setBlockStrides(*weights, weights_user_md, bShouldTransp ? std::vector<int>({1,0}) : std::vector<int>());
|
||||
|
||||
weights_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
if (bShouldTransp) {
|
||||
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1);
|
||||
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0);
|
||||
}
|
||||
else {
|
||||
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0);
|
||||
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1);
|
||||
}
|
||||
}
|
||||
// bias
|
||||
dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x);
|
||||
dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::x);
|
||||
mkldnnUtils::setBlockStrides(bias, bias_user_md);
|
||||
dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a);
|
||||
dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, bType, dnnl::memory::format_tag::a);
|
||||
mkldnnUtils::setBlockStrides(*bias, bias_user_md);
|
||||
|
||||
// z
|
||||
dnnl::memory::desc z_mkl_md = dnnl::memory::desc(zShape, zType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, format);
|
||||
mkldnnUtils::setBlockStrides(z, z_user_md);
|
||||
dnnl::memory::desc z_user_md = dnnl::memory::desc(zShape, zType, mkldnnUtils::getFormat(*z));
|
||||
mkldnnUtils::setBlockStrides(*z, z_user_md);
|
||||
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
|
@ -125,27 +115,24 @@ namespace sd {
|
|||
// provide memory buffers and check whether reorder is required
|
||||
|
||||
// input
|
||||
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_prim_desc.src_desc(), args[DNNL_ARG_SRC]);
|
||||
|
||||
// weights
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, op_prim_desc.weights_desc(), args[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// bias
|
||||
auto bias_mkl_mem = dnnl::memory(bias_mkl_md, engine, const_cast<void*>(bias->buffer()));
|
||||
args[DNNL_ARG_BIAS] = bias_mkl_mem;
|
||||
|
||||
// z
|
||||
auto z_user_mem = dnnl::memory(z_user_md, engine, z->buffer());
|
||||
const bool zReorder = op_prim_desc.dst_desc() != z_user_mem.get_desc();
|
||||
auto z_mkl_mem = zReorder ? dnnl::memory(op_prim_desc.dst_desc(), engine) : z_user_mem;
|
||||
args[DNNL_ARG_DST] = z_mkl_mem;
|
||||
auto z_user_mem = mkldnnUtils::loadDataToMklStream(*z, engine, stream, z_user_md, op_prim_desc.dst_desc(), args[DNNL_ARG_DST]);
|
||||
|
||||
// run calculations
|
||||
dnnl::inner_product_forward(op_prim_desc).execute(stream, args);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (zReorder)
|
||||
dnnl::reorder(z_mkl_mem, z_user_mem).execute(stream, z_mkl_mem, z_user_mem);
|
||||
if (op_prim_desc.dst_desc() != z_user_mem.get_desc())
|
||||
dnnl::reorder(args[DNNL_ARG_DST], z_user_mem).execute(stream, args[DNNL_ARG_DST], z_user_mem);
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
|
@ -160,7 +147,7 @@ namespace sd {
|
|||
|
||||
// [M,K] x [K,N] = [M,N]
|
||||
const int M = x->sizeAt(0);
|
||||
const int K = x->sizeAt(1); // K == wK
|
||||
const int K = x->sizeAt(1); // K == wK
|
||||
const int N = dLdz->sizeAt(1);
|
||||
// input dims
|
||||
dnnl::memory::dims xShape = dnnl::memory::dims({ M, K });
|
||||
|
@ -168,71 +155,53 @@ namespace sd {
|
|||
dnnl::memory::dims dLdzShape = dnnl::memory::dims({ M, N });
|
||||
|
||||
dnnl::memory::dims bShape = dnnl::memory::dims({ N });
|
||||
|
||||
// output dims
|
||||
dnnl::memory::dims dLdxShape = xShape;
|
||||
dnnl::memory::dims dLdwShape = wShape;
|
||||
|
||||
dnnl::memory::format_tag format = dnnl::memory::format_tag::ab;
|
||||
dnnl::memory::data_type dataType = dnnl::memory::data_type::f32;
|
||||
|
||||
// memory descriptors for arrays
|
||||
// x
|
||||
dnnl::memory::desc x_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, format);
|
||||
mkldnnUtils::setBlockStrides(x, x_user_md);
|
||||
dnnl::memory::desc x_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*x));
|
||||
mkldnnUtils::setBlockStrides(*x, x_user_md);
|
||||
|
||||
// weights
|
||||
dnnl::memory::desc weights_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, format);
|
||||
if (weights->ews() != 1 || weights->ordering() != 'c' || bShouldTransp) {
|
||||
dnnl::memory::desc weights_user_md = dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*weights));
|
||||
mkldnnUtils::setBlockStrides(*weights, weights_user_md, bShouldTransp ? std::vector<int>({1,0}) : std::vector<int>());
|
||||
|
||||
weights_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
if (bShouldTransp) {
|
||||
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(1);
|
||||
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(0);
|
||||
}
|
||||
else {
|
||||
weights_user_md.data.format_desc.blocking.strides[0] = weights->strideAt(0);
|
||||
weights_user_md.data.format_desc.blocking.strides[1] = weights->strideAt(1);
|
||||
}
|
||||
}
|
||||
// bias
|
||||
dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
|
||||
dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
|
||||
mkldnnUtils::setBlockStrides(bias, bias_user_md);
|
||||
dnnl::memory::desc bias_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc bias_user_md = dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*bias));
|
||||
mkldnnUtils::setBlockStrides(*bias, bias_user_md);
|
||||
|
||||
// dLdz
|
||||
dnnl::memory::desc dLdz_mkl_md = dnnl::memory::desc(dLdzShape, dataType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, format);
|
||||
mkldnnUtils::setBlockStrides(dLdz, dLdz_user_md);
|
||||
dnnl::memory::desc dLdz_user_md = dnnl::memory::desc(dLdzShape, dataType, mkldnnUtils::getFormat(*dLdz));
|
||||
mkldnnUtils::setBlockStrides(*dLdz, dLdz_user_md);
|
||||
|
||||
|
||||
// dLdw
|
||||
dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, format);
|
||||
dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, format);
|
||||
if (dLdw->ews() != 1 || dLdw->ordering() != 'c' || bShouldTransp) {
|
||||
|
||||
dLdw_user_md.data.format_kind = dnnl_blocked; // overrides format
|
||||
if (bShouldTransp) {
|
||||
dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(1);
|
||||
dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(0);
|
||||
}
|
||||
else {
|
||||
dLdw_user_md.data.format_desc.blocking.strides[0] = dLdw->strideAt(0);
|
||||
dLdw_user_md.data.format_desc.blocking.strides[1] = dLdw->strideAt(1);
|
||||
}
|
||||
}
|
||||
dnnl::memory::desc dLdw_mkl_md = dnnl::memory::desc(wShape, dataType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc dLdw_user_md = dnnl::memory::desc(wShape, dataType, mkldnnUtils::getFormat(*dLdw));
|
||||
mkldnnUtils::setBlockStrides(*dLdw, dLdw_user_md, bShouldTransp ? std::vector<int>({1,0}) : std::vector<int>());
|
||||
|
||||
// dLdb
|
||||
dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
|
||||
dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::x);
|
||||
mkldnnUtils::setBlockStrides(dLdb, dLdb_user_md);
|
||||
dnnl::memory::desc dLdb_mkl_md = dnnl::memory::desc(bShape, dataType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc dLdb_user_md = dnnl::memory::desc(bShape, dataType, mkldnnUtils::getFormat(*dLdb));
|
||||
mkldnnUtils::setBlockStrides(*dLdb, dLdb_user_md);
|
||||
|
||||
// dLdx
|
||||
dnnl::memory::desc dLdx_mkl_md = dnnl::memory::desc(xShape, dataType, dnnl::memory::format_tag::any);
|
||||
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, format);
|
||||
mkldnnUtils::setBlockStrides(dLdx, dLdx_user_md);
|
||||
dnnl::memory::desc dLdx_user_md = dnnl::memory::desc(xShape, dataType, mkldnnUtils::getFormat(*dLdx));
|
||||
mkldnnUtils::setBlockStrides(*dLdx, dLdx_user_md);
|
||||
|
||||
// create engine
|
||||
auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine());
|
||||
|
||||
// forward
|
||||
// operation primitive description
|
||||
dnnl::inner_product_forward::desc op_ff_desc(dnnl::prop_kind::forward_inference, x_mkl_md, weights_mkl_md, bias_mkl_md, dLdz_mkl_md);
|
||||
|
@ -254,34 +223,25 @@ namespace sd {
|
|||
dnnl::stream stream(engine);
|
||||
|
||||
// dLdz dw
|
||||
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]);
|
||||
mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bpdw_prim_desc.diff_dst_desc(), argsDw[DNNL_ARG_DIFF_DST]);
|
||||
|
||||
// dLdz - dx
|
||||
mkldnnUtils::loadDataToMklStream(dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]);
|
||||
mkldnnUtils::loadDataToMklStream(*dLdz, engine, stream, dLdz_user_md, op_bpdx_prim_desc.diff_dst_desc(), argsDx[DNNL_ARG_DIFF_DST]);
|
||||
|
||||
// input x for dw
|
||||
mkldnnUtils::loadDataToMklStream(x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]);
|
||||
mkldnnUtils::loadDataToMklStream(*x, engine, stream, x_user_md, op_bpdw_prim_desc.src_desc(), argsDw[DNNL_ARG_SRC]);
|
||||
|
||||
// weights - dx
|
||||
mkldnnUtils::loadDataToMklStream(weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]);
|
||||
mkldnnUtils::loadDataToMklStream(*weights, engine, stream, weights_user_md, op_bpdx_prim_desc.weights_desc(), argsDx[DNNL_ARG_WEIGHTS]);
|
||||
|
||||
// dLdw
|
||||
auto dLdw_user_mem = dnnl::memory(dLdw_user_md, engine, dLdw->buffer());
|
||||
const bool dLdwReorder = op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc();
|
||||
auto dLdw_mkl_mem = dLdwReorder ? dnnl::memory(op_bpdw_prim_desc.diff_weights_desc(), engine) : dLdw_user_mem;
|
||||
argsDw[DNNL_ARG_DIFF_WEIGHTS] = dLdw_mkl_mem;
|
||||
// dLdw
|
||||
auto dLdw_user_mem = mkldnnUtils::loadDataToMklStream(*dLdw, engine, stream, dLdw_user_md, op_bpdw_prim_desc.diff_weights_desc(), argsDw[DNNL_ARG_DIFF_WEIGHTS]);
|
||||
|
||||
// dLdx
|
||||
auto dLdx_user_mem = dnnl::memory(dLdx_user_md, engine, dLdx->buffer());
|
||||
const bool dLdxReorder = op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc();
|
||||
auto dLdx_mkl_mem = dLdxReorder ? dnnl::memory(op_bpdx_prim_desc.diff_src_desc(), engine) : dLdx_user_mem;
|
||||
argsDx[DNNL_ARG_DIFF_SRC] = dLdx_mkl_mem;
|
||||
// dLdx
|
||||
auto dLdx_user_mem = mkldnnUtils::loadDataToMklStream(*dLdx, engine, stream, dLdx_user_md, op_bpdx_prim_desc.diff_src_desc(), argsDx[DNNL_ARG_DIFF_SRC]);
|
||||
|
||||
// dLdb
|
||||
auto dLdb_user_mem = dnnl::memory(dLdb_user_md, engine, dLdb->buffer());
|
||||
const bool dLdbReorder = op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc();
|
||||
auto dLdb_mkl_mem = dLdbReorder ? dnnl::memory(op_bpdw_prim_desc.diff_bias_desc(), engine) : dLdb_user_mem;
|
||||
argsDw[DNNL_ARG_DIFF_BIAS] = dLdb_mkl_mem;
|
||||
auto dLdb_user_mem = mkldnnUtils::loadDataToMklStream(*dLdb, engine, stream, dLdb_user_md, op_bpdw_prim_desc.diff_bias_desc(), argsDw[DNNL_ARG_DIFF_BIAS]);
|
||||
|
||||
// run calculations dw
|
||||
dnnl::inner_product_backward_weights(op_bpdw_prim_desc).execute(stream, argsDw);
|
||||
|
@ -289,14 +249,14 @@ namespace sd {
|
|||
dnnl::inner_product_backward_data(op_bpdx_prim_desc).execute(stream, argsDx);
|
||||
|
||||
// reorder outputs if necessary
|
||||
if (dLdxReorder)
|
||||
dnnl::reorder(dLdx_mkl_mem, dLdx_user_mem).execute(stream, dLdx_mkl_mem, dLdx_user_mem);
|
||||
if (op_bpdx_prim_desc.diff_src_desc() != dLdx_user_mem.get_desc())
|
||||
dnnl::reorder(argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem).execute(stream, argsDx[DNNL_ARG_DIFF_SRC], dLdx_user_mem);
|
||||
|
||||
if (dLdwReorder)
|
||||
dnnl::reorder(dLdw_mkl_mem, dLdw_user_mem).execute(stream, dLdw_mkl_mem, dLdw_user_mem);
|
||||
if (op_bpdw_prim_desc.diff_weights_desc() != dLdw_user_mem.get_desc())
|
||||
dnnl::reorder(argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem).execute(stream, argsDw[DNNL_ARG_DIFF_WEIGHTS], dLdw_user_mem);
|
||||
|
||||
if (dLdbReorder)
|
||||
dnnl::reorder(dLdb_mkl_mem, dLdb_user_mem).execute(stream, dLdb_mkl_mem, dLdb_user_mem);
|
||||
if (op_bpdw_prim_desc.diff_bias_desc() != dLdb_user_mem.get_desc())
|
||||
dnnl::reorder(argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem).execute(stream, argsDw[DNNL_ARG_DIFF_BIAS], dLdb_user_mem);
|
||||
|
||||
stream.wait();
|
||||
}
|
||||
|
@ -315,7 +275,7 @@ namespace sd {
|
|||
const int wRank = w->rankOf();
|
||||
const int zRank = z->rankOf();
|
||||
|
||||
const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]
|
||||
const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]
|
||||
|
||||
REQUIRE_TRUE(xRank == 2, 0, "xw_plus_b MKL: Input x array should have rank equal 2, but got instead %i!", xRank);
|
||||
REQUIRE_TRUE(wRank == 2, 0, "xw_plus_b MKL: Input weights array should have rank equal 2, but got instead %i!", wRank);
|
||||
|
@ -378,7 +338,7 @@ namespace sd {
|
|||
const int wRank = w->rankOf();
|
||||
const int dLdzRank = dLdz->rankOf();
|
||||
|
||||
const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]
|
||||
const bool bShouldTransp = block.getIArguments()->size() > 0 ? (1 != INT_ARG(0)) : true; // [M,K] * [K,N] -> [M, N], mkl -> [M,K] * [N, K]^T -> [M, N]
|
||||
|
||||
REQUIRE_TRUE(x->rankOf() == 2, 0, "xw_plus_b BP MKL: Input x array should have rank equal 2, but got instead %i!", x->rankOf());
|
||||
REQUIRE_TRUE(w->rankOf() == 2, 0, "xw_plus_b BP MKL: Input weights array should have rank equal 2, but got instead %i!", w->rankOf());
|
||||
|
|
|
@ -107,6 +107,25 @@ namespace sd {
    //     samediff::Threads::parallel_tad(func, 0, numOfArrs);
    // }

    // static Nd4jLong strideOverContigAxis(const int axis, const Nd4jLong* inShapeInfo) {

    //     Nd4jLong result = 9223372036854775807LL;

    //     for(uint i = 0; i < shape::rank(inShapeInfo); ++i) {

    //         const auto currentStride = shape::stride(inShapeInfo)[i];

    //         if(i == axis || shape::shapeOf(inShapeInfo)[i] == 1)
    //             continue;

    //         if(result > currentStride)
    //             result = currentStride;
    //     }

    //     return result == 9223372036854775807LL ? 1 : result;
    // }

template <typename T>
void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inArrs, NDArray& output, const int axis) {

@ -150,7 +169,7 @@ void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inAr
    //         if(!areInputsContin || !allSameOrder)
    //             break;

    //         strideOfContigStride[i] = shape::strideOverContigAxis(axis, inArrs[i]->shapeInfo());
    //         strideOfContigStride[i] = strideOverContigAxis(axis, inArrs[i]->getShapeInfo());
    //     }
    // }

@ -158,7 +177,7 @@ void SpecialMethods<T>::concatCpuGeneric(const std::vector<const NDArray*>& inAr

    // if(luckCase2) { // for example {2,1,3} + {2,5,3} + {2,10,3} = {2,16,3}, here axis 1 should have stride = 1 for all input arrays and the output array

    //     const auto zStep = shape::strideOverContigAxis(axis, output.shapeInfo());
    //     const auto zStep = strideOverContigAxis(axis, output.getShapeInfo());

    //     for (uint i = 0; i < output.lengthOf() / output.sizeAt(axis); ++i) {
@ -184,7 +184,7 @@ TEST_F(DeclarableOpsTests16, test_range_2) {
    double tArgs[] = { -1.0, 1.0, 0.01 };

    auto shapes = ::calculateOutputShapes2(nullptr, op.getOpHash(), nullptr, nullptr, 0, tArgs, 3, nullptr, 0, nullptr, 0, nullptr, 0);
    shape::printShapeInfoLinear("Result", shapes->at(0));
    // shape::printShapeInfoLinear("Result", shapes->at(0));
    ASSERT_TRUE(shape::shapeEquals(z.shapeInfo(), shapes->at(0)));

    delete shapes;

@ -426,7 +426,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_hsv_6) {
        0.928968489f, 0.684074104f
    });

    //get subarray
    //get subarray
    //get subarray
    NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
    NDArray expected = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) });

@ -627,7 +627,7 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
    });

    auto actual = NDArrayFactory::create<float>('c', { 3 });
    //get subarray
    //get subarray
    NDArray subArrHsvs = hsvs.subarray({ NDIndex::all(), NDIndex::point(0) });
    subArrHsvs.reshapei({ 3 });
    NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });

@ -635,7 +635,7 @@ TEST_F(DeclarableOpsTests16, test_hsv_to_rgb_6) {
#if 0
    //[RANK][SHAPE][STRIDES][OPTIONS][EWS][ORDER]
    subArrHsvs.printShapeInfo("subArrHsvs");
#endif
#endif

    Context ctx(1);
    ctx.setInputArray(0, &subArrHsvs);

@ -855,7 +855,7 @@ TEST_F(DeclarableOpsTests16, test_rgb_to_yiq_6) {
        -0.04447775f, -0.44518381f
    });

    //get subarray
    //get subarray
    NDArray subArrRgbs = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
    NDArray expected = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
    subArrRgbs.reshapei({ 3 });

@ -1054,7 +1054,7 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
        0.280231822f, 1.91936605f
    });

    //get subarray
    //get subarray
    NDArray subArrYiqs = yiqs.subarray({ NDIndex::all(), NDIndex::point(0) });
    NDArray expected = rgbs.subarray({ NDIndex::all(), NDIndex::point(0) });
    subArrYiqs.reshapei({ 3 });
@ -1074,3 +1074,422 @@ TEST_F(DeclarableOpsTests16, test_yiq_to_rgb_6) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
||||
ASSERT_TRUE(expected.equalsTo(actual));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_1) {
|
||||
auto x= NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
|
||||
auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0});
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {4.0}, {});
|
||||
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
}
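The expected values here follow the standard clip-by-norm rule: when the L2 norm of x exceeds the clip value, x is scaled by clip / ||x||2, otherwise it is left unchanged. A small sketch of the arithmetic behind this test (plain C++ assuming <vector> and <cmath>, illustrative only):

    std::vector<double> x = {-3, 0, 0, 4, 0, 0};
    double norm2 = 0.0;                                  // accumulate sum of squares
    for (double v : x) norm2 += v * v;
    norm2 = std::sqrt(norm2);                            // 5.0
    double scale = norm2 > 4.0 ? 4.0 / norm2 : 1.0;      // clip = 4 -> scale = 0.8
    // scaled x = {-2.4, 0, 0, 3.2, 0, 0}, matching exp above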
|
||||
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_2) {
|
||||
auto x= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
|
||||
auto exp= NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {6.0}, {});
|
||||
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_3) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto unities = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., 1.});
|
||||
auto scale = NDArrayFactory::create<double>('c', {3, 1}, {1.1, 1., 0.9});
|
||||
|
||||
x.linspace(100.);
|
||||
|
||||
auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
|
||||
x /= xNorm1;
|
||||
xNorm1 = x.reduceAlongDimension(reduce::Norm2,{1}, true);
|
||||
|
||||
ASSERT_TRUE(unities.isSameShape(xNorm1));
|
||||
ASSERT_TRUE(unities.equalsTo(xNorm1));
|
||||
|
||||
x *= scale;
|
||||
xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {1.0}, {1});
|
||||
auto z = result.at(0);
|
||||
|
||||
auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true);
|
||||
auto exp = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., xNorm1.e<double>(2)});
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(&zNorm1));
|
||||
ASSERT_TRUE(exp.equalsTo(&zNorm1));
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_4) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 5}, {0.7044955, 0.55606544, 0.15833677, 0.001874401, 0.61595726, 0.3924779, 0.7414847, 0.4127324, 0.24026828, 0.26093036, 0.46741188, 0.01863421, 0.08528871, 0.529365, 0.5510694});
|
||||
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {0.405392, 0.319980, 0.091113, 0.001079, 0.354444, 0.225846, 0.426676, 0.237501, 0.138259, 0.150149, 0.268965, 0.010723, 0.049078, 0.304615, 0.317105});
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {1.f}, {});
|
||||
auto output = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_5) {
|
||||
|
||||
// auto x = NDArrayFactory::create<double>('c', {3, 5}, {1,2,3,4,5, 1,2,3,4,5, 1,2,3,4,5});
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 2.89271, 3.50524, 4.00892, 6., 7., 7.71389, 7.88678, 8.01784, 11., 12., 12.53507, 12.26833, 12.02676});
|
||||
// auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1,1,1,1,1,1,1,1,1,1,1,1,1,1,1});
|
||||
|
||||
x.linspace(1);
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {15.f}, {0});
|
||||
auto output = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_6) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {1., 2., 3., 4., 5., 4.95434, 5.78006, 6.60578, 7.43151, 8.25723, 5.64288, 6.15587, 6.66886, 7.18185, 7.69484});
|
||||
|
||||
x.linspace(1);
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {15.f}, {1});
|
||||
auto output = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_7) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957});
|
||||
|
||||
x.linspace(1);
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {15.f}, {0,1});
|
||||
auto output = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_8) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 5});
|
||||
auto exp = NDArrayFactory::create<double>('c', {3, 5}, {0.42597, 0.85194, 1.27791, 1.70389, 2.12986, 2.55583, 2.9818 , 3.40777, 3.83374, 4.25971, 4.68569, 5.11166, 5.53763, 5.9636 , 6.38957});
|
||||
|
||||
x.linspace(1);
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {15.}, {});
|
||||
auto output = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_9) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2}, {3., 4.});
|
||||
auto exp = NDArrayFactory::create<double>('c', {2}, {2.4, 3.2});
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {4.}, {});
|
||||
auto output = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_10) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>(6.);
|
||||
auto exp = NDArrayFactory::create<double>(5.);
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {5.}, {});
|
||||
auto output = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_11) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {2, 3, 4});
|
||||
auto exp = NDArrayFactory::create<double>('c', {2, 3, 4}, {1., 2., 3., 4., 4.44787, 5.33745, 6.22702, 7.1166 , 6.33046, 7.03384, 7.73723, 8.44061,
|
||||
13., 14., 15., 16., 15.12277, 16.01235, 16.90192, 17.7915 ,14.77107, 15.47446, 16.17784, 16.88123});
|
||||
|
||||
x.linspace(1);
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {35.}, {0, 2});
|
||||
auto output = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_12) {
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 3}, {1, 2, 3, 4, 5,6, 7, 8, 9});
|
||||
auto e = NDArrayFactory::create<double>('c', {3, 3}, {0.03198684, 0.06397368, 0.09596053, 0.12794736, 0.15993419, 0.19192106, 0.22390789, 0.25589472, 0.28788155});
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&x}, {0.54}, {});
|
||||
|
||||
ASSERT_EQ(e, *result.at(0));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_13) {
|
||||
|
||||
const int bS = 5;
|
||||
const int nOut = 4;
|
||||
const int axis = 0;
|
||||
const double clip = 2.;
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1]
|
||||
auto colVect = NDArrayFactory::create<double>('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1});
|
||||
auto expect = NDArrayFactory::create<double>('c', {bS, nOut});
|
||||
|
||||
auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut]
|
||||
|
||||
auto y = ( (x / norm2) * clip) * colVect ;
|
||||
auto temp = (x / norm2) * clip;
|
||||
|
||||
for (int j = 0; j < nOut; ++j) {
|
||||
auto yCol = y({0,0, j,j+1});
|
||||
const double norm2Col = yCol.reduceNumber(reduce::Norm2).e<double>(0);
|
||||
if (norm2Col <= clip)
|
||||
expect({0,0, j,j+1}).assign(yCol);
|
||||
else
|
||||
expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) );
|
||||
}
|
||||
|
||||
sd::ops::clipbynorm op;
|
||||
auto result = op.evaluate({&y}, {clip}, {axis});
|
||||
auto outFF = result.at(0);
|
||||
|
||||
ASSERT_TRUE(expect.isSameShape(outFF));
|
||||
ASSERT_TRUE(expect.equalsTo(outFF));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_bp_1) {
|
||||
|
||||
const int bS = 2;
|
||||
const int nOut = 3;
|
||||
const double clip = 0.7;
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
|
||||
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x}, {clip}, {});
|
||||
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {});
|
||||
|
||||
sd::ops::clipbynorm opFF;
|
||||
sd::ops::clipbynorm_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
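GradCheck::checkGrad validates the analytic gradients produced by the _bp op against numerical ones; conceptually it is a central-difference check along the lines of the following simplified, self-contained sketch (not the actual GradCheck implementation):

    #include <algorithm>
    #include <cmath>
    #include <functional>
    #include <vector>

    // 'loss' is any scalar function of the inputs; 'analytic' is the gradient from the _bp op.
    static bool checkGradNumerically(std::vector<double>& x, const std::vector<double>& analytic,
                                     const std::function<double(const std::vector<double>&)>& loss,
                                     double h = 1e-5, double tol = 1e-3) {
        for (size_t i = 0; i < x.size(); ++i) {
            const double orig = x[i];
            x[i] = orig + h; const double lp = loss(x);   // perturb up
            x[i] = orig - h; const double lm = loss(x);   // perturb down
            x[i] = orig;                                  // restore
            const double num = (lp - lm) / (2.0 * h);     // central difference
            if (std::abs(num - analytic[i]) > tol * std::max(1.0, std::abs(num)))
                return false;
        }
        return true;
    }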
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_bp_2) {
|
||||
|
||||
const int bS = 2;
|
||||
const int nOut = 3;
|
||||
const int axis = 0;
|
||||
const double clip = 0.7;
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
|
||||
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
|
||||
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});
|
||||
|
||||
sd::ops::clipbynorm opFF;
|
||||
sd::ops::clipbynorm_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbynorm_bp_3) {
|
||||
|
||||
const int bS = 2;
|
||||
const int nOut = 3;
|
||||
const int axis = 1;
|
||||
const double clip = 1.;
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
|
||||
auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});
|
||||
|
||||
const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
|
||||
const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});
|
||||
|
||||
sd::ops::clipbynorm opFF;
|
||||
sd::ops::clipbynorm_bp opBP;
|
||||
|
||||
const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
|
||||
|
||||
ASSERT_TRUE(isGradCorrect);
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbyavgnorm_1) {
|
||||
auto x = NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
|
||||
auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0});
|
||||
|
||||
sd::ops::clipbyavgnorm op;
|
||||
auto result = op.evaluate({&x}, {0.8}, {});
|
||||
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
}
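The expected output is consistent with clipping against the average L2 norm, i.e. ||x||2 divided by the number of elements; a short illustration of the numbers asserted above (comment-only, not the op's code):

    // x = {-3, 0, 0, 4, 0, 0}:  ||x||2 = 5,  average norm = 5 / 6 ≈ 0.8333
    // clip = 0.8 < 0.8333, so scale = 0.8 / (5.0 / 6.0) = 0.96
    // scaled x = {-2.88, 0, 0, 3.84, 0, 0}, matching exp above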
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests16, clipbyavgnorm_2) {
|
||||
auto x= NDArrayFactory::create<float>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
|
||||
auto exp= NDArrayFactory::create<float>('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f});
|
||||
|
||||
sd::ops::clipbyavgnorm op;
|
||||
auto result = op.evaluate({&x}, {0.9}, {});
|
||||
|
||||
auto z = result.at(0);
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
}
|
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_1) {

    const int bS = 2;
    const int nOut = 3;
    const double clip = 0.7;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 });    // uniform random in range [0,1]
    auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

    const OpArgsHolder argsHolderFF({&x}, {clip}, {});
    const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {});

    sd::ops::clipbyavgnorm opFF;
    sd::ops::clipbyavgnorm_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_2) {

    const int bS = 2;
    const int nOut = 3;
    const int axis = 1;
    const double clip = 1.;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 });    // uniform random in range [0,1]
    auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

    const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
    const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});

    sd::ops::clipbyavgnorm opFF;
    sd::ops::clipbyavgnorm_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests16, clipbyavgnorm_bp_3) {

    NDArray x('c', {2, 3, 4}, {-0.14 ,0.96 ,0.47 ,-0.98 ,0.03 ,0.95 ,0.33 ,-0.97 ,0.59 ,-0.92 ,-0.12 ,-0.33 ,0.82 ,-0.76 ,-0.69 ,-0.95 ,-0.77 ,0.25 ,-0.35 ,0.94 ,0.50 ,0.04 ,0.61 ,0.99}, sd::DataType::DOUBLE);
    NDArray gradO('c', {2, 3, 4}, sd::DataType::DOUBLE);

    const OpArgsHolder argsHolderFF({&x}, {0.7}, {0,2});
    const OpArgsHolder argsHolderBP({&x, &gradO}, {0.7}, {0,2});

    sd::ops::clipbyavgnorm opFF;
    sd::ops::clipbyavgnorm_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}

@@ -50,7 +50,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_1) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

@@ -68,7 +68,7 @@ TEST_F(DeclarableOpsTests3, Test_Tile_2) {
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

}

TEST_F(DeclarableOpsTests3, Test_Permute_1) {
@@ -123,7 +123,7 @@ TEST_F(DeclarableOpsTests3, Test_Unique_1) {
    ASSERT_TRUE(expI.isSameShape(i));
    ASSERT_TRUE(expI.equalsTo(i));

}

TEST_F(DeclarableOpsTests3, Test_Unique_2) {
@@ -171,7 +171,7 @@ TEST_F(DeclarableOpsTests3, Test_Rint_1) {

    ASSERT_TRUE(exp.equalsTo(z));

}

@@ -226,7 +226,7 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {
    ASSERT_TRUE(exp0.isSameShape(z0));
    ASSERT_TRUE(exp0.equalsTo(z0));

    auto result1 = op.evaluate({&x, &axis}, {1}, {});

@@ -244,94 +244,6 @@ TEST_F(DeclarableOpsTests3, Test_Norm_2) {

}

TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_1) {
    auto x   = NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-2.88, 0.0, 0.0, 3.84, 0.0, 0.0});

    sd::ops::clipbyavgnorm op;
    auto result = op.evaluate({&x}, {0.8}, {});

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}

TEST_F(DeclarableOpsTests3, Test_ClipByAvgNorm_2) {
    auto x   = NDArrayFactory::create<float>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
    auto exp = NDArrayFactory::create<float>('c', {2, 3}, {-3.f, 0.0f, 0.0f, 4.f, 0.0f, 0.0f});

    sd::ops::clipbyavgnorm op;
    auto result = op.evaluate({&x}, {0.9}, {});

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}

TEST_F(DeclarableOpsTests3, Test_ClipByNorm_1) {
    auto x   = NDArrayFactory::create<double>('c', {2, 3}, {-3.0, 0.0, 0.0, 4.0, 0.0, 0.0});
    auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-2.4, 0.0, 0.0, 3.2, 0.0, 0.0});

    sd::ops::clipbynorm op;
    auto result = op.evaluate({&x}, {4.0}, {});

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}

TEST_F(DeclarableOpsTests3, Test_ClipByNorm_2) {
    auto x   = NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});
    auto exp = NDArrayFactory::create<double>('c', {2, 3}, {-3.0f, 0.0f, 0.0f, 4.0f, 0.0f, 0.0f});

    sd::ops::clipbynorm op;
    auto result = op.evaluate({&x}, {6.0}, {});

    auto z = result.at(0);

    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, Test_ClipByNorm_3) {

    auto x = NDArrayFactory::create<double>('c', {3, 5});
    auto unities = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., 1.});
    auto scale = NDArrayFactory::create<double>('c', {3, 1}, {1.1, 1., 0.9});

    x.linspace(100.);

    auto xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);
    x /= xNorm1;
    xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);

    ASSERT_TRUE(unities.isSameShape(xNorm1));
    ASSERT_TRUE(unities.equalsTo(xNorm1));

    x *= scale;
    xNorm1 = x.reduceAlongDimension(reduce::Norm2, {1}, true);

    sd::ops::clipbynorm op;
    auto result = op.evaluate({&x}, {1.0}, {1});
    auto z = result.at(0);

    auto zNorm1 = z->reduceAlongDimension(reduce::Norm2, {1}, true);
    auto exp = NDArrayFactory::create<double>('c', {3, 1}, {1., 1., xNorm1.e<double>(2)});

    ASSERT_TRUE(exp.isSameShape(&zNorm1));
    ASSERT_TRUE(exp.equalsTo(&zNorm1));
}

TEST_F(DeclarableOpsTests3, Test_ListDiff_1) {
    auto x = NDArrayFactory::create<float>('c', {6}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto y = NDArrayFactory::create<float>('c', {3}, {1.f, 3.f, 5.f});

@@ -551,7 +463,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_1) {
    }

    delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) {
@@ -579,7 +491,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_2) {
    }

    delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) {
@@ -607,7 +519,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_3) {
    }

    delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) {
@@ -635,7 +547,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_4) {
    }

    delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) {
@@ -663,7 +575,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_5) {
    }

    delete exp;

}

@@ -692,7 +604,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_6) {
    }

    delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {
@@ -722,7 +634,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_7) {
    }

    delete exp;

}

TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) {
@@ -734,7 +646,7 @@ TEST_F(DeclarableOpsTests3, Test_Batched_Gemm_Validation_1) {
    sd::ops::batched_gemm op;
    try {
        auto result = op.evaluate({&a, &b, &x, &x, &x, &y, &y, &y}, {}, {112, 112, 2, 3, 5, 5, 3, 2, 3});

        ASSERT_TRUE(false);
    } catch (std::invalid_argument &e) {
        //

@@ -875,7 +787,7 @@ TEST_F(DeclarableOpsTests3, sruCell_test3) {
    ASSERT_TRUE(expCt.isSameShape(ct));
    ASSERT_TRUE(expCt.equalsTo(ct));

}

////////////////////////////////////////////////////////////////////
@@ -946,7 +858,7 @@ TEST_F(DeclarableOpsTests3, gruCell_test2) {
    ASSERT_TRUE(expHt.isSameShape(ht));
    ASSERT_TRUE(expHt.equalsTo(ht));

}

////////////////////////////////////////////////////////////////////
@@ -1001,7 +913,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test1) {
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////
@@ -1021,7 +933,7 @@ TEST_F(DeclarableOpsTests3, invertPermutation_test2) {
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////
@@ -1099,7 +1011,7 @@ TEST_F(DeclarableOpsTests3, diag_test_vector) {

    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete input;
}

@@ -1120,7 +1032,7 @@ TEST_F(DeclarableOpsTests3, diag_test_col_vector) {
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    delete input;
}
///////////////////////////////////////////////////////////////////
@@ -1245,7 +1157,7 @@ TEST_F(DeclarableOpsTests3, matrixSetDiag_test2) {
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////
@@ -1551,7 +1463,7 @@ TEST_F(DeclarableOpsTests3, betainc_test8) {
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output, 1e-6));

}

///////////////////////////////////////////////////////////////////
@@ -1576,7 +1488,7 @@ TEST_F(DeclarableOpsTests3, betainc_test9) {

    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////
@@ -1642,7 +1554,7 @@ TEST_F(DeclarableOpsTests3, betainc_test12) {

    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////
@@ -1689,7 +1601,7 @@ TEST_F(DeclarableOpsTests3, zeta_test2) {
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////
@@ -1831,7 +1743,7 @@ TEST_F(DeclarableOpsTests3, zeta_test8) {
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////
@@ -1856,7 +1768,7 @@ TEST_F(DeclarableOpsTests3, zeta_test9) {
    ASSERT_TRUE(expected.isSameShape(z));
    ASSERT_TRUE(expected.equalsTo(z));

    //
    //
}

///////////////////////////////////////////////////////////////////
@@ -1881,7 +1793,7 @@ TEST_F(DeclarableOpsTests3, zeta_test10) {
    ASSERT_TRUE(expected.isSameShape(z));
    ASSERT_TRUE(expected.equalsTo(z));

    //
    //
}

@@ -1908,7 +1820,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) {
    x.assign(0.5);

    auto expected = NDArrayFactory::create<double>('c', {3,3}, {4.934802, -16.828796, 97.409088, -771.474243, 7691.113770, -92203.460938, 1290440.250000, -20644900.000000, 3.71595e+08});

    sd::ops::polygamma op;
    auto result = op.evaluate({&n, &x}, {}, {});

@@ -1920,7 +1832,7 @@ TEST_F(DeclarableOpsTests3, polygamma_test1) {
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

}

///////////////////////////////////////////////////////////////////
@@ -2263,7 +2175,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) {

    ASSERT_TRUE(expS.equalsTo(s));
    ASSERT_TRUE(expS.isSameShape(s));

}

///////////////////////////////////////////////////////////////////
@@ -2416,7 +2328,7 @@ TEST_F(DeclarableOpsTests3, svd_test7) {
    // ASSERT_NEAR(sd::math::nd4j_abs(expV.e<float>(i)), sd::math::nd4j_abs(v->e<float>(i)), 1e-5);
    // }

    //
    //
    // }

///////////////////////////////////////////////////////////////////

(File diff suppressed because it is too large)

@ -57,7 +57,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_1) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {
|
||||
|
@ -78,7 +78,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_2) {
|
|||
|
||||
ASSERT_EQ(exp, *z);
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
|
||||
|
@ -100,7 +100,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_3) {
|
|||
ASSERT_TRUE(z->isEmpty());
|
||||
//ASSERT_EQ(exp, *z);
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
|
||||
|
@ -122,7 +122,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_4) {
|
|||
ASSERT_TRUE(z->equalsTo(exp));
|
||||
//ASSERT_EQ(exp, *z);
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_04) {
|
||||
|
@ -185,7 +185,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_5) {
|
|||
auto z = result.at(0);
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
|
||||
|
@ -205,7 +205,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_6) {
|
|||
auto z = result.at(0);
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
|
||||
|
@ -226,7 +226,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_Once_Again_7) {
|
|||
auto z = result.at(0);
|
||||
//ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
|
||||
|
@ -248,7 +248,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_1) {
|
|||
auto z = result.at(0);
|
||||
//ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
|
||||
|
@ -270,7 +270,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_2) {
|
|||
auto z = result.at(0);
|
||||
//ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
|
||||
|
@ -292,7 +292,7 @@ TEST_F(DeclarableOpsTests6, Test_StridedSlice_BP_3) {
|
|||
auto z = result.at(0);
|
||||
//ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
|
||||
|
@ -309,7 +309,7 @@ TEST_F(DeclarableOpsTests6, Test_Simple_Scalar_1) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_Order_1) {
|
||||
|
@ -326,7 +326,7 @@ TEST_F(DeclarableOpsTests6, Test_Order_1) {
|
|||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
ASSERT_NE(x.ordering(), z->ordering());
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, cumSum_1) {
|
||||
|
@ -342,7 +342,7 @@ TEST_F(DeclarableOpsTests6, cumSum_1) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, cumSum_2) {
|
||||
|
@ -359,7 +359,7 @@ TEST_F(DeclarableOpsTests6, cumSum_2) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, cumSum_3) {
|
||||
|
@ -375,7 +375,7 @@ TEST_F(DeclarableOpsTests6, cumSum_3) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, cumSum_4) {
|
||||
|
@ -391,7 +391,7 @@ TEST_F(DeclarableOpsTests6, cumSum_4) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, cumSum_5) {
|
||||
|
@ -406,7 +406,7 @@ TEST_F(DeclarableOpsTests6, cumSum_5) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, cumSum_6) {
|
||||
|
@ -421,7 +421,7 @@ TEST_F(DeclarableOpsTests6, cumSum_6) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, cumSum_7) {
|
||||
|
@ -436,7 +436,7 @@ TEST_F(DeclarableOpsTests6, cumSum_7) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, cumSum_8) {
|
||||
|
@ -452,7 +452,7 @@ TEST_F(DeclarableOpsTests6, cumSum_8) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -477,7 +477,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
|
|||
ASSERT_EQ(Status::OK(), result.status());
|
||||
auto z = result.at(0);
|
||||
ASSERT_TRUE(expFF.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
//************************************//
|
||||
exclusive = 1; reverse = 0;
|
||||
|
@ -486,7 +486,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
|
|||
ASSERT_EQ(Status::OK(), result.status());
|
||||
z = result.at(0);
|
||||
ASSERT_TRUE(expTF.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
//************************************//
|
||||
exclusive = 0; reverse = 1;
|
||||
|
@ -495,7 +495,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
|
|||
ASSERT_EQ(Status::OK(), result.status());
|
||||
z = result.at(0);
|
||||
ASSERT_TRUE(expFT.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
//************************************//
|
||||
exclusive = 1; reverse = 1;
|
||||
|
@ -504,7 +504,7 @@ TEST_F(DeclarableOpsTests6, cumSum_9) {
|
|||
ASSERT_EQ(Status::OK(), result.status());
|
||||
z = result.at(0);
|
||||
ASSERT_TRUE(expTT.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
@ -517,7 +517,7 @@ TEST_F(DeclarableOpsTests6, cumSum_10) {
|
|||
auto result = op.evaluate({&x, &y}, {}, {1, 1});
|
||||
ASSERT_EQ(Status::OK(), result.status());
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -536,7 +536,7 @@ TEST_F(DeclarableOpsTests6, cumSum_11) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -555,7 +555,7 @@ TEST_F(DeclarableOpsTests6, cumSum_12) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -574,7 +574,7 @@ TEST_F(DeclarableOpsTests6, cumSum_13) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -593,7 +593,7 @@ TEST_F(DeclarableOpsTests6, cumSum_14) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -612,7 +612,7 @@ TEST_F(DeclarableOpsTests6, cumSum_15) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -631,7 +631,7 @@ TEST_F(DeclarableOpsTests6, cumSum_16) {
|
|||
ASSERT_TRUE(z->ews() == 1);
|
||||
ASSERT_TRUE(x.ews() == 1);
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -664,7 +664,7 @@ TEST_F(DeclarableOpsTests6, cumSum_17) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -697,7 +697,7 @@ TEST_F(DeclarableOpsTests6, cumSum_18) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -731,7 +731,7 @@ TEST_F(DeclarableOpsTests6, cumSum_19) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -764,7 +764,7 @@ TEST_F(DeclarableOpsTests6, cumSum_20) {
|
|||
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@@ -779,30 +779,40 @@ TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_1) {
    auto res = op.evaluate({&x, &y, &z}, {}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, res.status());
    // res.at(0)->printIndexedBuffer("MergeMaxIndex Result is ");
    // res.at(0)->printShapeInfo("Shape info for MergeMaxIdex");
    // x.printIndexedBuffer("Input is");
    ASSERT_TRUE(res.at(0)->equalsTo(exp));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_2) {

-   auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f});
-   auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
-   auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 60.f, 7.f, 80.f});
-   auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 2, 1, 2});
+   auto x = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 2.f, 3.f, 4.f, 5.f, 60.f, 7.f, 8.f});
+   auto y = NDArrayFactory::create<double>('c', {2, 2, 2}, {10.f, 2.f, 30.f, 4.f, 50.f, 6.f, 70.f, 8.f});
+   auto z = NDArrayFactory::create<double>('c', {2, 2, 2}, {1.f, 20.f, 3.f, 40.f, 5.f, 6.f, 7.f, 80.f});
+   auto exp = NDArrayFactory::create<Nd4jLong>('c', {2, 2, 2}, {1, 2, 1, 2, 1, 0, 1, 2});
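    // mergemaxindex returns, per element, the index of the input array that holds the element-wise maximum;
    // e.g. at flat position 5 the values are x=60, y=6, z=6, so the expected index there is 0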

    sd::ops::mergemaxindex op;

    auto ress = op.evaluate({&x, &y, &z}, {}, {sd::DataType::INT64});

    ASSERT_EQ(ND4J_STATUS_OK, ress.status());
    // res.at(0)->printIndexedBuffer("MergeMaxIndex2 Result is ");
    // res.at(0)->printShapeInfo("Shape info for MergeMaxIdex2");
    // x.printIndexedBuffer("Input is");
    ASSERT_TRUE(ress.at(0)->equalsTo(exp));
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests6, TestMergeMaxIndex_3) {

    auto x1 = NDArrayFactory::create<double>('c', {3}, {1.f, 0.f, 0.f});
    auto x2 = NDArrayFactory::create<double>('c', {3}, {0.f, 1.f, 0.f});
    auto x3 = NDArrayFactory::create<double>('c', {3}, {0.f, 0.f, 1.f});
    NDArray z('c', {3}, sd::DataType::INT32);
    NDArray expZ('c', {3}, {0, 1, 2}, sd::DataType::INT32);
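    // each input holds its maximum at a different position, so the expected per-element indices are {0, 1, 2}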

    sd::ops::mergemaxindex op;
    auto result = op.execute({&x1, &x2, &x3}, {&z}, {}, {}, {});

    ASSERT_EQ(Status::OK(), result);
    ASSERT_TRUE(z.equalsTo(expZ));
}

////////////////////////////////////////////////////////////////////////////////
@ -818,7 +828,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_1) {
|
|||
//res.at(0)->printIndexedBuffer("Result is ");
|
||||
//x.printIndexedBuffer("Input is");
|
||||
|
||||
|
||||
|
||||
}
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests6, TestMod_1) {
|
||||
|
@ -834,7 +844,7 @@ TEST_F(DeclarableOpsTests6, TestMod_1) {
|
|||
// res.at(0)->printIndexedBuffer("MOD Result is ");
|
||||
// x.printIndexedBuffer("Input is");
|
||||
ASSERT_TRUE(res.at(0)->equalsTo(exp));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -853,7 +863,7 @@ TEST_F(DeclarableOpsTests6, TestMod_BP_1) {
|
|||
|
||||
// x.printIndexedBuffer("Input is");
|
||||
ASSERT_TRUE(res.at(0)->equalsTo(exp));
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -870,7 +880,7 @@ TEST_F(DeclarableOpsTests6, TestRank_1) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, res.status());
|
||||
|
||||
ASSERT_TRUE(res.at(0)->equalsTo(exp));
|
||||
|
||||
|
||||
}
|
||||
TEST_F(DeclarableOpsTests6, TestDropout_2) {
|
||||
// auto x0 = NDArrayFactory::create<double>('c', {10, 10});
|
||||
|
@ -883,7 +893,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_2) {
|
|||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res.status());
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, TestDropout_3) {
|
||||
|
@ -898,7 +908,7 @@ TEST_F(DeclarableOpsTests6, TestDropout_3) {
|
|||
|
||||
ASSERT_EQ(ND4J_STATUS_OK, res.status());
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -922,7 +932,7 @@ TEST_F(DeclarableOpsTests6, MaxPoolWithArgmax_1) {
|
|||
|
||||
ASSERT_TRUE(expI.equalsTo(res.at(1)));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -947,7 +957,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_1) {
|
|||
ASSERT_TRUE(sumExp.equalsTo(res.at(1)));
|
||||
ASSERT_TRUE(sqrExp.equalsTo(res.at(2)));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -979,7 +989,7 @@ TEST_F(DeclarableOpsTests6, SufficientStatistics_2) {
|
|||
ASSERT_TRUE(sumExp.equalsTo(res.at(1)));
|
||||
ASSERT_TRUE(sqrExp.equalsTo(res.at(2)));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1270,7 +1280,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_1) {
|
|||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
// ASSERT_TRUE(expNorm.equalsTo(norm));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1310,7 +1320,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_2) {
|
|||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
ASSERT_TRUE(exp.equalsTo(y));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1344,7 +1354,7 @@ TEST_F(DeclarableOpsTests6, ClipByGlobalNorm_3) {
|
|||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
ASSERT_TRUE(exp.equalsTo(y));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1365,7 +1375,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_1) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1386,7 +1396,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_2) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1407,7 +1417,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_3) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1428,7 +1438,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_4) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1452,7 +1462,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_5) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1477,7 +1487,7 @@ TEST_F(DeclarableOpsTests6, MatrixDeterminant_6) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1496,7 +1506,7 @@ TEST_F(DeclarableOpsTests6, LogMatrixDeterminant_1) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1514,7 +1524,7 @@ TEST_F(DeclarableOpsTests6, LogDet_1) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1533,7 +1543,7 @@ TEST_F(DeclarableOpsTests6, LogDet_2) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1552,7 +1562,7 @@ TEST_F(DeclarableOpsTests6, LogDet_3) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1596,7 +1606,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_1) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1615,7 +1625,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_010) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1634,7 +1644,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_01) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1653,7 +1663,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_02) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1700,7 +1710,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_2) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
*/
|
||||
TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
|
||||
|
@ -1733,7 +1743,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_03) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1767,7 +1777,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_3) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1801,7 +1811,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_4) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1835,7 +1845,7 @@ TEST_F(DeclarableOpsTests6, MatrixInverse_04) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -1864,7 +1874,7 @@ TEST_F(DeclarableOpsTests6, ReluLayer_1) {
|
|||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_Reduce3_Edge) {
|
||||
|
@ -1917,7 +1927,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test1) {
|
|||
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
|
||||
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -1960,7 +1970,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test2) {
|
|||
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
|
||||
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2003,7 +2013,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test3) {
|
|||
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
|
||||
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2045,7 +2055,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test4) {
|
|||
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
|
||||
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2087,7 +2097,7 @@ TEST_F(DeclarableOpsTests6, static_rnn_test5) {
|
|||
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
|
||||
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2141,7 +2151,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test1) {
|
|||
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
|
||||
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2194,7 +2204,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test2) {
|
|||
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
|
||||
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -2247,7 +2257,7 @@ TEST_F(DeclarableOpsTests6, static_bidir_rnn_test3) {
|
|||
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
|
||||
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2290,7 +2300,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test1) {
|
|||
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
|
||||
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -2335,7 +2345,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test2) {
|
|||
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
|
||||
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2377,7 +2387,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test3) {
|
|||
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
|
||||
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2418,7 +2428,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test4) {
|
|||
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
|
||||
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2459,7 +2469,7 @@ TEST_F(DeclarableOpsTests6, dynamic_rnn_test5) {
|
|||
ASSERT_TRUE(expHFinal.isSameShape(hFinal));
|
||||
ASSERT_TRUE(expHFinal.equalsTo(hFinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2521,7 +2531,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test1) {
|
|||
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
|
||||
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2581,7 +2591,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test2) {
|
|||
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
|
||||
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2637,7 +2647,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test3) {
|
|||
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
|
||||
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
///////////////////////////////////////////////////////////////////
|
||||
|
@ -2696,7 +2706,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test4) {
|
|||
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
|
||||
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {
|
||||
|
@ -2749,7 +2759,7 @@ TEST_F(DeclarableOpsTests6, dynamic_bidir_rnn_test5) {
|
|||
ASSERT_TRUE(expHBWfinal.isSameShape(hBWfinal));
|
||||
ASSERT_TRUE(expHBWfinal.equalsTo(hBWfinal));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
@ -2763,7 +2773,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_1) {
|
|||
|
||||
ASSERT_EQ(e, *result.at(0));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {
|
||||
|
@ -2776,7 +2786,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_2) {
|
|||
|
||||
ASSERT_EQ(e, *result.at(0));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
|
||||
|
@ -2789,7 +2799,7 @@ TEST_F(DeclarableOpsTests6, Test_Diag_119_3) {
|
|||
|
||||
ASSERT_EQ(e, *result.at(0));
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
|
(File diff suppressed because it is too large)

@@ -236,10 +236,10 @@ TEST_F(DeclarableOpsTests9, ScalarOpTest_MixedOrders_1) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test1) {

-   auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
-   auto x1 = NDArrayFactory::create<double>('c', {2,2,4});
-   auto x2 = NDArrayFactory::create<double>('c', {2,1,4});
-   auto exp = NDArrayFactory::create<double>('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
+   auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
+   auto x1 = NDArrayFactory::create<float>('c', {2,2,4});
+   auto x2 = NDArrayFactory::create<float>('c', {2,1,4});
+   auto exp = NDArrayFactory::create<float>('c', {2,6,4}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
                        13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.});

    x0.linspace(1);
|
@ -261,10 +261,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test2) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>('c', {1,3,1});
|
||||
auto x1 = NDArrayFactory::create<double>('c', {1,2,1});
|
||||
auto x2 = NDArrayFactory::create<double>('c', {1,1,1});
|
||||
auto exp = NDArrayFactory::create<double>('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});
|
||||
auto x0 = NDArrayFactory::create<float>('c', {1,3,1});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {1,2,1});
|
||||
auto x2 = NDArrayFactory::create<float>('c', {1,1,1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {1,6,1}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});
|
||||
|
||||
x0.linspace(1);
|
||||
x1.linspace(1);
|
||||
|
@ -285,10 +285,10 @@ TEST_F(DeclarableOpsTests9, concat_test2) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test3) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>('c', {3});
|
||||
auto x1 = NDArrayFactory::create<double>('c', {2});
|
||||
auto x2 = NDArrayFactory::create<double>('c', {1});
|
||||
auto exp = NDArrayFactory::create<double>('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});
|
||||
auto x0 = NDArrayFactory::create<float>('c', {3});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {2});
|
||||
auto x2 = NDArrayFactory::create<float>('c', {1});
|
||||
auto exp = NDArrayFactory::create<float>('c', {6}, {1.f, 2.f, 3.f, 1.f, 2.f, 1.f});
|
||||
|
||||
x0.linspace(1);
|
||||
x1.linspace(1);
|
||||
|
@ -300,21 +300,17 @@ TEST_F(DeclarableOpsTests9, concat_test3) {
|
|||
ASSERT_EQ(ND4J_STATUS_OK, result.status());
|
||||
auto output = result.at(0);
|
||||
|
||||
output->printBuffer();
|
||||
|
||||
ASSERT_TRUE(exp.isSameShape(output));
|
||||
ASSERT_TRUE(exp.equalsTo(output));
|
||||
|
||||
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test4) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>('c', {1,1,1}, {1.f});
|
||||
auto x1 = NDArrayFactory::create<double>('c', {1,1,1}, {2.f});
|
||||
auto x2 = NDArrayFactory::create<double>('c', {1,1,1}, {3.f});
|
||||
auto exp = NDArrayFactory::create<double>('c', {1,3,1}, {1.f, 2.f, 3.f});
|
||||
auto x0 = NDArrayFactory::create<float>('c', {1,1,1}, {1.f});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {1,1,1}, {2.f});
|
||||
auto x2 = NDArrayFactory::create<float>('c', {1,1,1}, {3.f});
|
||||
auto exp = NDArrayFactory::create<float>('c', {1,3,1}, {1.f, 2.f, 3.f});
|
||||
|
||||
sd::ops::concat op;
|
||||
|
||||
|
@ -331,10 +327,10 @@ TEST_F(DeclarableOpsTests9, concat_test4) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test5) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>(1.f);
|
||||
auto x1 = NDArrayFactory::create<double>('c', {1}, {2.f});
|
||||
auto x2 = NDArrayFactory::create<double>(3.f);
|
||||
auto exp = NDArrayFactory::create<double>('c', {3}, {1.f, 2.f, 3.f});
|
||||
auto x0 = NDArrayFactory::create<float>(1.f);
|
||||
auto x1 = NDArrayFactory::create<float>('c', {1}, {2.f});
|
||||
auto x2 = NDArrayFactory::create<float>(3.f);
|
||||
auto exp = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});
|
||||
|
||||
sd::ops::concat op;
|
||||
|
||||
|
@ -351,10 +347,10 @@ TEST_F(DeclarableOpsTests9, concat_test5) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test6) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>(1.f);
|
||||
auto x1 = NDArrayFactory::create<double>('c', {2}, {2.f, 20.f});
|
||||
auto x2 = NDArrayFactory::create<double>(3.f);
|
||||
auto exp = NDArrayFactory::create<double>('c', {4}, {1.f, 2.f, 20.f, 3.f});
|
||||
auto x0 = NDArrayFactory::create<float>(1.f);
|
||||
auto x1 = NDArrayFactory::create<float>('c', {2}, {2.f, 20.f});
|
||||
auto x2 = NDArrayFactory::create<float>(3.f);
|
||||
auto exp = NDArrayFactory::create<float>('c', {4}, {1.f, 2.f, 20.f, 3.f});
|
||||
|
||||
sd::ops::concat op;
|
||||
|
||||
|
@ -371,10 +367,10 @@ TEST_F(DeclarableOpsTests9, concat_test6) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test7) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>(1.f);
|
||||
auto x1 = NDArrayFactory::create<double>(2.f);
|
||||
auto x2 = NDArrayFactory::create<double>(3.f);
|
||||
auto exp = NDArrayFactory::create<double>('c', {3}, {1.f, 2.f, 3.f});
|
||||
auto x0 = NDArrayFactory::create<float>(1.f);
|
||||
auto x1 = NDArrayFactory::create<float>(2.f);
|
||||
auto x2 = NDArrayFactory::create<float>(3.f);
|
||||
auto exp = NDArrayFactory::create<float>('c', {3}, {1.f, 2.f, 3.f});
|
||||
|
||||
sd::ops::concat op;
|
||||
|
||||
|
@ -391,8 +387,8 @@ TEST_F(DeclarableOpsTests9, concat_test7) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test8) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>(1.f);
|
||||
auto exp = NDArrayFactory::create<double>('c', {1}, {1.f});
|
||||
auto x0 = NDArrayFactory::create<float>(1.f);
|
||||
auto exp = NDArrayFactory::create<float>('c', {1}, {1.f});
|
||||
|
||||
sd::ops::concat op;
|
||||
|
||||
|
@ -409,8 +405,8 @@ TEST_F(DeclarableOpsTests9, concat_test8) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test9) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>('c', {1}, {1.f});
|
||||
auto exp = NDArrayFactory::create<double>('c', {1}, {1.f});
|
||||
auto x0 = NDArrayFactory::create<float>('c', {1}, {1.f});
|
||||
auto exp = NDArrayFactory::create<float>('c', {1}, {1.f});
|
||||
|
||||
sd::ops::concat op;
|
||||
|
||||
|
@ -427,10 +423,10 @@ TEST_F(DeclarableOpsTests9, concat_test9) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test10) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
|
||||
auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
|
||||
auto x2 = NDArrayFactory::create<double>('c', {2,1,4});
|
||||
auto exp = NDArrayFactory::create<double>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
|
||||
auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
|
||||
auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
|
||||
auto x2 = NDArrayFactory::create<float>('c', {2,1,4});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
|
||||
13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f});
|
||||
|
||||
x0.linspace(1);
|
||||
|
@ -452,10 +448,10 @@ TEST_F(DeclarableOpsTests9, concat_test10) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test11) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
|
||||
auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
|
||||
auto x2 = NDArrayFactory::create<double>('f', {2,1,4});
|
||||
auto exp = NDArrayFactory::create<double>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
|
||||
auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
|
||||
auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
|
||||
auto x2 = NDArrayFactory::create<float>('f', {2,1,4});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
|
||||
13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f});
|
||||
|
||||
x0.linspace(1);
|
||||
|
@ -477,10 +473,10 @@ TEST_F(DeclarableOpsTests9, concat_test11) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test12) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>('c', {2,3,4});
|
||||
auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
|
||||
auto x2 = NDArrayFactory::create<double>('f', {2,1,4});
|
||||
auto exp = NDArrayFactory::create<double>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
|
||||
auto x0 = NDArrayFactory::create<float>('c', {2,3,4});
|
||||
auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
|
||||
auto x2 = NDArrayFactory::create<float>('f', {2,1,4});
|
||||
auto exp = NDArrayFactory::create<float>('c', {2,6,4}, { 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 9.f, 10.f, 11.f, 12.f, 1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f, 1.f, 2.f, 3.f, 4.f,
|
||||
13.f, 14.f, 15.f, 16.f,17.f, 18.f, 19.f, 20.f,21.f, 22.f, 23.f, 24.f, 9.f, 10.f, 11.f, 12.f,13.f, 14.f, 15.f, 16.f, 5.f, 6.f, 7.f, 8.f});
|
||||
|
||||
x0.linspace(1);
|
||||
|
@ -502,10 +498,10 @@ TEST_F(DeclarableOpsTests9, concat_test12) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test13) {
|
||||
|
||||
auto x0 = NDArrayFactory::create<double>('f', {2,3,4});
|
||||
auto x1 = NDArrayFactory::create<double>('f', {2,2,4});
|
||||
auto x2 = NDArrayFactory::create<double>('f', {2,1,4});
|
||||
auto exp = NDArrayFactory::create<double>('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f,
|
||||
auto x0 = NDArrayFactory::create<float>('f', {2,3,4});
|
||||
auto x1 = NDArrayFactory::create<float>('f', {2,2,4});
|
||||
auto x2 = NDArrayFactory::create<float>('f', {2,1,4});
|
||||
auto exp = NDArrayFactory::create<float>('f', {2,6,4}, { 1.f, 13.f, 5.f, 17.f, 9.f, 21.f, 1.f, 9.f, 5.f, 13.f, 1.f, 5.f, 2.f, 14.f, 6.f, 18.f,10.f, 22.f, 2.f, 10.f, 6.f, 14.f, 2.f, 6.f,
|
||||
3.f, 15.f, 7.f, 19.f,11.f, 23.f, 3.f, 11.f, 7.f, 15.f, 3.f, 7.f, 4.f, 16.f, 8.f, 20.f,12.f, 24.f, 4.f, 12.f, 8.f, 16.f, 4.f, 8.f});
|
||||
|
||||
x0.linspace(1);
|
||||
|
@ -527,8 +523,8 @@ TEST_F(DeclarableOpsTests9, concat_test13) {
|
|||
|
||||
TEST_F(DeclarableOpsTests9, concat_test14) {
|
||||
|
||||
NDArray x0('c', {1, 40, 60}, sd::DataType::DOUBLE);
|
||||
NDArray x1('c', {1, 40, 60}, sd::DataType::DOUBLE);
|
||||
NDArray x0('c', {1, 40, 60}, sd::DataType::FLOAT32);
|
||||
NDArray x1('c', {1, 40, 60}, sd::DataType::FLOAT32);
|
||||
|
||||
x0 = 1.;
|
||||
x1 = 2.;
|
||||
|
@ -544,7 +540,7 @@ TEST_F(DeclarableOpsTests9, concat_test14) {
|
|||
|
||||
for (int e = 0; e < numOfTads; ++e) {
|
||||
NDArray tad = (*z)(e, {0});
|
||||
auto mean = tad.meanNumber().e<double>(0);
|
||||
auto mean = tad.meanNumber().e<float>(0);
|
||||
ASSERT_NEAR((e+1)*1., mean, 1e-5);
|
||||
}
|
||||
|
||||
|
@ -552,9 +548,9 @@ TEST_F(DeclarableOpsTests9, concat_test14) {
|
|||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests9, concat_test15) {
|
||||
auto x = NDArrayFactory::create<double>('c', {2}, {1, 0});
|
||||
auto y = NDArrayFactory::create<double> (3.0f);
|
||||
auto exp = NDArrayFactory::create<double>('c', {3}, {1, 0, 3});
|
||||
auto x = NDArrayFactory::create<float>('c', {2}, {1, 0});
|
||||
auto y = NDArrayFactory::create<float> (3.0f);
|
||||
auto exp = NDArrayFactory::create<float>('c', {3}, {1, 0, 3});
|
||||
|
||||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {0});
|
||||
|
@ -571,9 +567,9 @@ TEST_F(DeclarableOpsTests9, concat_test15) {
|
|||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test16) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {0,2,3});
|
||||
auto y = NDArrayFactory::create<double>('c', {0,2,3});
|
||||
auto exp = NDArrayFactory::create<double>('c', {0,2,3});
|
||||
auto x = NDArrayFactory::create<float>('c', {0,2,3});
|
||||
auto y = NDArrayFactory::create<float>('c', {0,2,3});
|
||||
auto exp = NDArrayFactory::create<float>('c', {0,2,3});
|
||||
|
||||
sd::ops::concat op;
|
||||
auto result = op.evaluate({&x, &y}, {}, {0});
|
||||
|
@ -587,8 +583,8 @@ TEST_F(DeclarableOpsTests9, concat_test16) {
|
|||
//////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test17) {
|
||||
|
||||
NDArray x0('c', {1, 55, 40}, sd::DataType::DOUBLE);
|
||||
NDArray x1('c', {1, 55, 40}, sd::DataType::DOUBLE);
|
||||
NDArray x0('c', {1, 55, 40}, sd::DataType::FLOAT32);
|
||||
NDArray x1('c', {1, 55, 40}, sd::DataType::FLOAT32);
|
||||
|
||||
x0 = 1.;
|
||||
x1 = 2.;
|
||||
|
@ -606,7 +602,7 @@ TEST_F(DeclarableOpsTests9, concat_test17) {
|
|||
|
||||
for (int e = 0; e < numOfTads; ++e) {
|
||||
NDArray tad = (*z)(e, {0});
|
||||
auto mean = tad.meanNumber().e<double>(0);
|
||||
auto mean = tad.meanNumber().e<float>(0);
|
||||
ASSERT_NEAR((e+1)*1., mean, 1e-5);
|
||||
}
|
||||
}
|
||||
|
@ -664,10 +660,10 @@ TEST_F(DeclarableOpsTests9, concat_test19) {
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test20) {
|
||||
auto x0 = NDArrayFactory::create<double>('c', {1, 100, 150});
|
||||
auto x1 = NDArrayFactory::create<double>('c', {1, 100, 150});
|
||||
auto x2 = NDArrayFactory::create<double>('c', {1, 100, 150});
|
||||
auto x3 = NDArrayFactory::create<double>('c', {1, 100, 150});
|
||||
auto x0 = NDArrayFactory::create<float>('c', {1, 100, 150});
|
||||
auto x1 = NDArrayFactory::create<float>('c', {1, 100, 150});
|
||||
auto x2 = NDArrayFactory::create<float>('c', {1, 100, 150});
|
||||
auto x3 = NDArrayFactory::create<float>('c', {1, 100, 150});
|
||||
|
||||
x0.assign(1.0);
|
||||
x1.assign(2.0);
|
||||
|
@ -685,8 +681,8 @@ TEST_F(DeclarableOpsTests9, concat_test20) {
|
|||
|
||||
for (int e = 0; e < numOfTads; e++) {
|
||||
NDArray tad = (*z)(e, {0});
|
||||
auto mean = tad.meanNumber().e<double>(0);
|
||||
ASSERT_NEAR((double) e+1, mean, 1e-5);
|
||||
auto mean = tad.meanNumber().e<float>(0);
|
||||
ASSERT_NEAR((float) e+1, mean, 1e-5);
|
||||
}
|
||||
|
||||
|
||||
|
@ -710,10 +706,10 @@ TEST_F(DeclarableOpsTests9, concat_test21) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test22) {
|
||||
|
||||
NDArray x0('c', {1,6}, {1,2,3,4,5,6});
|
||||
NDArray x1('c', {1,6}, {7,8,9,10,11,12});
|
||||
NDArray output('f', {2,6}, sd::DataType::DOUBLE);
|
||||
NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12});
|
||||
NDArray x0('c', {1,6}, {1,2,3,4,5,6}, sd::DataType::FLOAT32);
|
||||
NDArray x1('c', {1,6}, {7,8,9,10,11,12}, sd::DataType::FLOAT32);
|
||||
NDArray output('f', {2,6}, sd::DataType::FLOAT32);
|
||||
NDArray exp('c', {2,6}, {1,2,3,4,5,6,7,8,9,10,11,12}, sd::DataType::FLOAT32);
|
||||
|
||||
sd::ops::concat op;
|
||||
|
||||
|
@ -726,10 +722,10 @@ TEST_F(DeclarableOpsTests9, concat_test22) {
|
|||
////////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests9, concat_test23) {
|
||||
|
||||
NDArray x0('c', {1,4}, {1,2,3,4});
|
||||
NDArray x1('c', {1,4}, {5,6,7,8});
|
||||
NDArray output('c', {2,4}, sd::DataType::DOUBLE);
|
||||
NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8});
|
||||
NDArray x0('c', {1,4}, {1,2,3,4},sd::DataType::FLOAT32);
|
||||
NDArray x1('c', {1,4}, {5,6,7,8},sd::DataType::FLOAT32);
|
||||
NDArray output('c', {2,4}, sd::DataType::FLOAT32);
|
||||
NDArray exp('c', {2,4}, {1,2,3,4,5,6,7,8}, sd::DataType::FLOAT32);
|
||||
|
||||
sd::ops::concat op;
|
||||
|
||||
|
@ -741,10 +737,10 @@ TEST_F(DeclarableOpsTests9, concat_test23) {

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test24) {
    auto x = NDArrayFactory::create<double>('c', {2, 1}, {1, 1});
    auto y = NDArrayFactory::create<double>('c', {2, 1}, {0, 0});
    auto e = NDArrayFactory::create<double>('c', {2, 2}, {1, 0, 1, 0});
    auto z = NDArrayFactory::create<double>('c', {2, 2});
    auto x = NDArrayFactory::create<float>('c', {2, 1}, {1, 1});
    auto y = NDArrayFactory::create<float>('c', {2, 1}, {0, 0});
    auto e = NDArrayFactory::create<float>('c', {2, 2}, {1, 0, 1, 0});
    auto z = NDArrayFactory::create<float>('c', {2, 2});

    sd::ops::concat op;
    auto status = op.execute({&x, &y}, {&z}, {}, {1}, {});
@ -756,10 +752,10 @@ TEST_F(DeclarableOpsTests9, concat_test24) {
////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test25) {

    auto x0 = NDArrayFactory::create<double>('c', {1,4}, {1,2,3,4});
    auto x1 = NDArrayFactory::create<double>('c', {1,4}, {5,6,7,8});
    auto axis = NDArrayFactory::create<double>('c', {1}, {0.});
    auto exp = NDArrayFactory::create<double>('c', {2,4}, {1,2,3,4,5,6,7,8});
    auto x0 = NDArrayFactory::create<float>('c', {1,4}, {1,2,3,4});
    auto x1 = NDArrayFactory::create<float>('c', {1,4}, {5,6,7,8});
    auto axis = NDArrayFactory::create<float>('c', {1}, {0.});
    auto exp = NDArrayFactory::create<float>('c', {2,4}, {1,2,3,4,5,6,7,8});

    sd::ops::concat op;

@ -793,7 +789,7 @@ TEST_F(DeclarableOpsTests9, concat_test26) {

    ASSERT_EQ(ND4J_STATUS_OK, result.status());
    auto output = result.at(0);
    output->printLinearBuffer();
    // output->printLinearBuffer();

    ASSERT_TRUE(exp.isSameShape(output));
    ASSERT_TRUE(exp.equalsTo(output));
@ -802,10 +798,10 @@ TEST_F(DeclarableOpsTests9, concat_test26) {
//////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, concat_test27) {

    auto x1 = NDArrayFactory::create<double>('c', {0,1});
    auto x2 = NDArrayFactory::create<double>('c', {0,1});
    auto x3 = NDArrayFactory::create<double>('c', {0,1});
    auto x4 = NDArrayFactory::create<double>('c', {0,1});
    auto x1 = NDArrayFactory::create<float>('c', {0,1});
    auto x2 = NDArrayFactory::create<float>('c', {0,1});
    auto x3 = NDArrayFactory::create<float>('c', {0,1});
    auto x4 = NDArrayFactory::create<float>('c', {0,1});

    std::vector<Nd4jLong> expShape = {0, 4};

@ -1245,109 +1241,6 @@ TEST_F(DeclarableOpsTests9, test_unstack_SGO_1) {

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_test12) {

    const int bS = 5;
    const int nOut = 4;
    const int axis = 0;
    const double clip = 2.;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.897 ,0.173 ,0.931 ,0.736 ,0.540 ,0.953 ,0.278 ,0.573 ,0.787 ,0.320 ,0.776 ,0.338 ,0.311 ,0.835 ,0.909 ,0.890 ,0.290}); // uniform random in range [0,1]
    auto colVect = NDArrayFactory::create<double>('c', {bS, 1}, {0.9, 0.95, 1.00, 1.05, 1.1});
    auto expect = NDArrayFactory::create<double>('c', {bS, nOut});

    auto norm2 = x.reduceAlongDimension(reduce::Norm2, {axis}, true); // norm2 has shape [1, nOut]

    auto y = ( (x / norm2) * clip) * colVect ;
    auto temp = (x / norm2) * clip;

    for (int j = 0; j < nOut; ++j) {
        auto yCol = y({0,0, j,j+1});
        const double norm2Col = yCol.reduceNumber(reduce::Norm2).e<double>(0);
        if (norm2Col <= clip)
            expect({0,0, j,j+1}).assign(yCol);
        else
            expect({0,0, j,j+1}).assign ( yCol * (clip / norm2Col) );
    }

    sd::ops::clipbynorm op;
    auto result = op.evaluate({&y}, {clip}, {axis});
    auto outFF = result.at(0);

    ASSERT_TRUE(expect.isSameShape(outFF));
    ASSERT_TRUE(expect.equalsTo(outFF));

}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test1) {

    const int bS = 2;
    const int nOut = 3;
    const double clip = 0.7;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
    auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

    const OpArgsHolder argsHolderFF({&x}, {clip}, {});
    const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {});

    sd::ops::clipbynorm opFF;
    sd::ops::clipbynorm_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test2) {

    const int bS = 2;
    const int nOut = 3;
    const int axis = 0;
    const double clip = 0.7;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
    auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

    const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
    const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});

    sd::ops::clipbynorm opFF;
    sd::ops::clipbynorm_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, clipbynorm_bp_test3) {

    const int bS = 2;
    const int nOut = 3;
    const int axis = 1;
    const double clip = 1.;

    auto x = NDArrayFactory::create<double>('c', {bS, nOut}, {0.412 ,0.184 ,0.961 ,0.173 ,0.736 ,0.540 }); // uniform random in range [0,1]
    auto gradO = NDArrayFactory::create<double>('c', {bS, nOut});

    const OpArgsHolder argsHolderFF({&x}, {clip}, {axis});
    const OpArgsHolder argsHolderBP({&x, &gradO}, {clip}, {axis});

    sd::ops::clipbynorm opFF;
    sd::ops::clipbynorm_bp opBP;

    const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);

    ASSERT_TRUE(isGradCorrect);
}

////////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests9, cumprod_1) {
