DNNL/MKLDNN dilated causal conv1d + betainc (#103)

* - add same-mode padding calculation in the causal conv1d op for the right MKL paddings

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct causal condition in mkldnnUtils.cpp

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct code that caused additional rounding errors in the betainc op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - use float in place of the template parameter in the NaN assignment in the betainc op

Signed-off-by: Yurii <iuriish@yahoo.com>
master
Alex Black 2019-12-04 22:50:17 +11:00 committed by raver119
parent cb18d3d996
commit 578a5abb68
7 changed files with 166 additions and 110 deletions


@@ -38,6 +38,12 @@ CONFIGURABLE_OP_IMPL(betainc, 3, 1, false, 0, 0) {
     auto b = INPUT_VARIABLE(1);
     auto x = INPUT_VARIABLE(2);
 
+    // just skip op if input is empty
+    if (x->isEmpty()) {
+        *x = DataTypeUtils::nanOrZero<float>();
+        return Status::OK();
+    }
+
     auto output = OUTPUT_VARIABLE(0);
 
     REQUIRE_TRUE(a->isSameShape(b) && a->isSameShape(x), 0, "CONFIGURABLE_OP betainc: all three input arrays must have the same shapes, but got a=%s, b=%s and x=%s instead !", ShapeUtils::shapeAsString(a).c_str(), ShapeUtils::shapeAsString(b).c_str(), ShapeUtils::shapeAsString(x).c_str());
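
The new empty-input guard returns early with DataTypeUtils::nanOrZero<float>(); per the commit message, the template parameter of this NaN assignment was pinned to float so the expression no longer depends on the op's compile-time type. A minimal stand-in for what such a helper presumably does (the name mirrors the nd4j one, but this snippet is illustrative, not the library's code):

    #include <limits>

    // Illustrative: quiet NaN where the type has one, zero otherwise.
    template <typename T>
    T nanOrZero() {
        return std::numeric_limits<T>::has_quiet_NaN
                   ? std::numeric_limits<T>::quiet_NaN()
                   : static_cast<T>(0);
    }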


@@ -31,61 +31,56 @@ namespace helpers {
 ///////////////////////////////////////////////////////////////////
 // modified Lentz's algorithm for continued fractions,
 // reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions”
 template <typename T>
 static T continuedFraction(const T a, const T b, const T x) {
 
     const T min = DataTypeUtils::min<T>() / DataTypeUtils::eps<T>();
     const T aPlusb = a + b;
-    T val, delta, aPlus2i;
-
-    // first iteration
-    T c = 1;
-    T d = static_cast<T>(1) - aPlusb * x / (a + static_cast<T>(1));
-    if(math::nd4j_abs<T>(d) < min)
-        d = min;
-    d = static_cast<T>(1) / d;
-    T f = d;
+    T val, aPlus2i;
+
+    T t2 = 1;
+    T t1 = static_cast<T>(1) - aPlusb * x / (a + static_cast<T>(1));
+    if(math::nd4j_abs<T>(t1) < min)
+        t1 = min;
+    t1 = static_cast<T>(1) / t1;
+    T result = t1;
 
-    for(uint i = 1; i <= maxIter; i += 2) {
+    for(uint i = 1; i <= maxIter; ++i) {
 
         aPlus2i = a + static_cast<T>(2*i);
 
-        /***** even part *****/
         val = i * (b - i) * x / ((aPlus2i - static_cast<T>(1)) * aPlus2i);
-        // d
-        d = static_cast<T>(1) + val * d;
-        if(math::nd4j_abs<T>(d) < min)
-            d = min;
-        d = static_cast<T>(1) / d;
-        // c
-        c = static_cast<T>(1) + val / c;
-        if(math::nd4j_abs<T>(c) < min)
-            c = min;
-        // f
-        f *= c * d;
+        // t1
+        t1 = static_cast<T>(1) + val * t1;
+        if(math::nd4j_abs<T>(t1) < min)
+            t1 = min;
+        t1 = static_cast<T>(1) / t1;
+        // t2
+        t2 = static_cast<T>(1) + val / t2;
+        if(math::nd4j_abs<T>(t2) < min)
+            t2 = min;
+        // result
+        result *= t2 * t1;
 
-        /***** odd part *****/
         val = -(a + i) * (aPlusb + i) * x / ((aPlus2i + static_cast<T>(1)) * aPlus2i);
-        // d
-        d = static_cast<T>(1) + val * d;
-        if(math::nd4j_abs<T>(d) < min)
-            d = min;
-        d = static_cast<T>(1) / d;
-        // c
-        c = static_cast<T>(1) + val / c;
-        if(math::nd4j_abs<T>(c) < min)
-            c = min;
-        // f
-        delta = c * d;
-        f *= delta;
+        // t1
+        t1 = static_cast<T>(1) + val * t1;
+        if(math::nd4j_abs<T>(t1) < min)
+            t1 = min;
+        t1 = static_cast<T>(1) / t1;
+        // t2
+        t2 = static_cast<T>(1) + val / t2;
+        if(math::nd4j_abs<T>(t2) < min)
+            t2 = min;
+        // result
+        val = t2 * t1;
+        result *= val;
 
         // condition to stop loop
-        if(math::nd4j_abs<T>(delta - static_cast<T>(1)) <= DataTypeUtils::eps<T>())
-            return f;
+        if(math::nd4j_abs<T>(val - static_cast<T>(1)) <= DataTypeUtils::eps<T>())
+            return result;
     }
 
-    return std::numeric_limits<float>::infinity(); // no convergence, more iterations are required
+    return DataTypeUtils::infOrMax<T>(); // no convergence, more iterations are required, return infinity
 }
 
 ///////////////////////////////////////////////////////////////////
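
For orientation, continuedFraction evaluates the standard continued-fraction expansion of the regularized incomplete beta function (the same form as in Numerical Recipes); the two val assignments per iteration are its even- and odd-indexed coefficients:

    I_x(a,b) = \frac{x^a (1-x)^b}{a\,B(a,b)} \cdot
               \cfrac{1}{1 + \cfrac{d_1}{1 + \cfrac{d_2}{1 + \cdots}}},
    \qquad
    d_{2m}   = \frac{m\,(b-m)\,x}{(a+2m-1)(a+2m)}, \quad
    d_{2m+1} = -\frac{(a+m)(a+b+m)\,x}{(a+2m)(a+2m+1)}.

When x > (a+1)/(a+b+2) the fraction converges slowly, so the code instead uses the symmetry relation I_x(a,b) = 1 - I_{1-x}(b,a), which is the branch labelled "symmetry relation" in the next hunk.
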
@@ -110,10 +105,9 @@ static T betaIncCore(T a, T b, T x) {
     const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
 
     if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
-        return front * continuedFraction(a, b, x) / a;
+        return front * continuedFraction<T>(a, b, x) / a;
     else // symmetry relation
-        return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x) / b;
+        return static_cast<T>(1) - front * continuedFraction<T>(b, a, static_cast<T>(1) - x) / b;
 }
 
 ///////////////////////////////////////////////////////////////////
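
Putting the front factor and the Lentz loop together, a self-contained double-precision sketch of the whole computation (illustrative only; the function name and the 200-iteration cap are chosen here, not taken from the library):

    #include <cmath>
    #include <limits>

    // Regularized incomplete beta I_x(a,b) via modified Lentz; sketch only.
    double betaIncSketch(double a, double b, double x) {
        if (x <= 0.0) return 0.0;
        if (x >= 1.0) return 1.0;
        if (x > (a + 1.0) / (a + b + 2.0))          // symmetry relation
            return 1.0 - betaIncSketch(b, a, 1.0 - x);

        // front = x^a (1-x)^b / B(a,b), computed in log space
        const double front = std::exp(a * std::log(x) + b * std::log1p(-x)
                             - (std::lgamma(a) + std::lgamma(b) - std::lgamma(a + b)));

        const double tiny = std::numeric_limits<double>::min()
                          / std::numeric_limits<double>::epsilon();
        double t2 = 1.0;
        double t1 = 1.0 - (a + b) * x / (a + 1.0);
        if (std::fabs(t1) < tiny) t1 = tiny;
        t1 = 1.0 / t1;
        double result = t1;

        for (int i = 1; i <= 200; ++i) {
            const double ap2i = a + 2.0 * i;
            // even coefficient d_{2i}
            double val = i * (b - i) * x / ((ap2i - 1.0) * ap2i);
            t1 = 1.0 + val * t1; if (std::fabs(t1) < tiny) t1 = tiny; t1 = 1.0 / t1;
            t2 = 1.0 + val / t2; if (std::fabs(t2) < tiny) t2 = tiny;
            result *= t2 * t1;
            // odd coefficient d_{2i+1}
            val = -(a + i) * (a + b + i) * x / ((ap2i + 1.0) * ap2i);
            t1 = 1.0 + val * t1; if (std::fabs(t1) < tiny) t1 = tiny; t1 = 1.0 / t1;
            t2 = 1.0 + val / t2; if (std::fabs(t2) < tiny) t2 = tiny;
            val = t2 * t1;
            result *= val;
            if (std::fabs(val - 1.0) <= std::numeric_limits<double>::epsilon())
                break;                               // converged
        }
        return front * result / a;
    }

Fed the inputs of betainc_test11 below, e.g. betaIncSketch(0.7717, 0.9281, 0.9441), this should land near the corrected expected value 0.912156.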


@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (t2) 2015-2018 Skymind, Inc.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at

@@ -39,54 +39,50 @@ __device__ T continuedFractionCuda(const T a, const T b, const T x) {
     const T min = DataTypeUtils::min<T>() / DataTypeUtils::eps<T>();
     const T aPlusb = a + b;
-    T val, delta, aPlus2i;
-
-    // first iteration
-    T c = 1;
-    T d = static_cast<T>(1) - aPlusb * x / (a + static_cast<T>(1));
-    if(math::nd4j_abs<T>(d) < min)
-        d = min;
-    d = static_cast<T>(1) / d;
-    T f = d;
+    T val, aPlus2i;
+
+    T t2 = coeffs[1];
+    T t1 = coeffs[0];
+    if(math::nd4j_abs<T>(t1) < min)
+        t1 = min;
+    t1 = static_cast<T>(1) / t1;
+    T result = t1;
 
-    for(uint i = 1; i <= maxIter; i += 2) {
+    for(uint i = 1; i <= maxIter; ++i) {
 
-        aPlus2i = a + static_cast<T>(2*i);
+        const uint i2 = 2*i;
+        aPlus2i = a + static_cast<T>(i2);
 
-        /***** even part *****/
-        // d
-        d = static_cast<T>(1) + coeffs[i - 1] * d;
-        if(math::nd4j_abs<T>(d) < min)
-            d = min;
-        d = static_cast<T>(1) / d;
-        // c
-        c = static_cast<T>(1) + coeffs[i - 1] / c;
-        if(math::nd4j_abs<T>(c) < min)
-            c = min;
-        // f
-        f *= c * d;
+        // t1
+        t1 = static_cast<T>(1) + coeffs[i2] * t1;
+        if(math::nd4j_abs<T>(t1) < min)
+            t1 = min;
+        t1 = static_cast<T>(1) / t1;
+        // t2
+        t2 = static_cast<T>(1) + coeffs[i2] / t2;
+        if(math::nd4j_abs<T>(t2) < min)
+            t2 = min;
+        // result
+        result *= t2 * t1;
 
-        /***** odd part *****/
-        // d
-        d = static_cast<T>(1) + coeffs[i] * d;
-        if(math::nd4j_abs<T>(d) < min)
-            d = min;
-        d = static_cast<T>(1) / d;
-        // c
-        c = static_cast<T>(1) + coeffs[i] / c;
-        if(math::nd4j_abs<T>(c) < min)
-            c = min;
-        // f
-        delta = c * d;
-        f *= delta;
+        // t1
+        t1 = static_cast<T>(1) + coeffs[i2 + 1] * t1;
+        if(math::nd4j_abs<T>(t1) < min)
+            t1 = min;
+        t1 = static_cast<T>(1) / t1;
+        // t2
+        t2 = static_cast<T>(1) + coeffs[i2 + 1] / t2;
+        if(math::nd4j_abs<T>(t2) < min)
+            t2 = min;
+        // result
+        val = t2 * t1;
+        result *= val;
 
         // condition to stop loop
-        if(math::nd4j_abs<T>(delta - static_cast<T>(1)) <= DataTypeUtils::eps<T>())
-            return f;
+        if(math::nd4j_abs<T>(val - static_cast<T>(1)) <= DataTypeUtils::eps<T>())
+            return result;
     }
 
-    return 1.f / 0.f; // no convergence, more iterations are required
+    return DataTypeUtils::infOrMax<T>(); // no convergence, more iterations are required, return infinity
 }
 
 ///////////////////////////////////////////////////////////////////

@@ -112,7 +108,14 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
         b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
         x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
 
-        symmCond = x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2));
+        symmCond = x > (a + static_cast<T>(1)) / (a + b + static_cast<T>(2));
+
+        if(symmCond) { // swap a and b, x = 1 - x
+            T temp = a;
+            a = b;
+            b = temp;
+            x = static_cast<T>(1) - x;
+        }
     }
 
     __syncthreads();

@@ -124,23 +127,17 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
     }
 
     if (x == static_cast<T>(0) || x == static_cast<T>(1)) {
-        z = x;
+        z = symmCond ? static_cast<T>(1) - x : x;
         return;
     }
 
-    if(threadIdx.x % 2 == 0) { /***** even part *****/
-        const int m = threadIdx.x + 1;
-        if(symmCond)
-            sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast<T>(1)) * (a + 2 * m));
-        else
-            sharedMem[threadIdx.x] = m * (a - m) * (1.f-x) / ((b + 2 * m - static_cast<T>(1)) * (b + 2 * m));
-    }
-    else { /***** odd part *****/
-        const int m = threadIdx.x;
-        if(symmCond)
-            sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast<T>(1)) * (a + 2 * m));
-        else
-            sharedMem[threadIdx.x] = -(b + m) * (a + b + m) * (1.f-x) / ((b + 2 * m + static_cast<T>(1)) * (b + 2 * m));
-    }
+    // calculate two coefficients per thread
+    if(threadIdx.x != 0) {
+        const int i = threadIdx.x;
+        const T aPlus2i = a + 2*i;
+        sharedMem[2*i]     = i * (b - i) * x / ((aPlus2i - static_cast<T>(1)) * aPlus2i);
+        sharedMem[2*i + 1] = -(a + i) * (a + b + i) * x / ((aPlus2i + static_cast<T>(1)) * aPlus2i);
+    }
 
     __syncthreads();

@@ -150,10 +147,13 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
     const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
     const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
 
-    if (symmCond)
-        z = front * continuedFractionCuda(a, b, x) / a;
-    else // symmetry relation
-        z = static_cast<T>(1) - front * continuedFractionCuda(b, a, static_cast<T>(1) - x) / b;
+    sharedMem[0] = static_cast<T>(1) - (a + b) * x / (a + static_cast<T>(1));
+    sharedMem[1] = static_cast<T>(1);
+
+    z = front * continuedFractionCuda(a, b, x) / a;
+
+    if(symmCond) // symmetry relation
+        z = static_cast<T>(1) - z;
 }

@@ -174,7 +174,7 @@ void betaInc(nd4j::LaunchContext* context, const NDArray& a, const NDArray& b, c
     const int threadsPerBlock = maxIter;
     const int blocksPerGrid = output.lengthOf();
-    const int sharedMem = output.sizeOfT() * threadsPerBlock + 128;
+    const int sharedMem = 2 * output.sizeOfT() * threadsPerBlock + 128;
 
     const auto xType = x.dataType();
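
The launch-configuration change at the end follows from the kernel rewrite: each of the maxIter threads in a block now publishes two continued-fraction coefficients (one even, one odd) to shared memory instead of one, so the dynamic shared-memory request doubles. A quick sanity check, assuming float outputs and an illustrative maxIter of 70 (the real constant lives elsewhere in the helper):

    // 2 coefficients per thread * 4 bytes * 70 threads + 128 bytes slack = 688 bytes
    const int threadsPerBlock = 70;  // = maxIter, one thread per CF iteration
    const int sharedMemBytes  = 2 * sizeof(float) * threadsPerBlock + 128;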


@@ -141,7 +141,7 @@ namespace nd4j {
     void getMKLDNNMemoryDescConv2d(
-                int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
+                int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
                 int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
                 const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
                 dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,

@@ -154,9 +154,11 @@ namespace nd4j {
         dnnl::memory::dims conv_bias_tz = { oC };
         dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
 
+        const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW;    // dH == 1 for causal mode in conv1d
+
         conv_strides = { sH, sW };
         conv_padding = { pH, pW };
-        conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
+        conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
         conv_dilation = { dH-1, dW-1};
 
         auto type = dnnl::memory::data_type::f32;

@@ -220,7 +222,7 @@ namespace nd4j {
     }
     void getMKLDNNMemoryDescConv3d(
-                int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
+                int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW,
                 int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
                 const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
                 dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
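
The pWSame line is the conv1d fix named in the commit title: oneDNN requires the right-hand padding to be spelled out explicitly, and once the kernel is dilated (dW > 1) the arithmetic has to use the effective kernel extent (kW - 1) * dW + 1 rather than kW. A dilation-aware sketch of the usual total/left/right split (names invented here for illustration; the diff's pWSame corresponds to the left half):

    // SAME-style padding split for a dilated 1-D axis; illustration only.
    inline void samePaddingSplit(int iW, int kW, int sW, int dW, int oW,
                                 int& pLeft, int& pRight) {
        const int kWEff = (kW - 1) * dW + 1;            // effective (dilated) kernel
        const int total = (oW - 1) * sW + kWEff - iW;   // zeros needed in total
        pLeft  = total / 2;                             // == pWSame in the hunk above
        pRight = total - pLeft;                         // what oneDNN must be told
    }

A causal convolution, by contrast, places the entire (kW - 1) * dW pad on the left so each output step sees only current and past inputs; paddingMode == 2 selects that mode in these ops, as the conv1d_causal tests below show.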


@@ -86,7 +86,7 @@ namespace nd4j{
      * Utility methods for MKLDNN
      */
     void getMKLDNNMemoryDescConv2d(
-                int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
+                int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
                 int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
                 const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
                 dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,


@@ -1040,6 +1040,39 @@ TEST_F(ConvolutionTests1, conv1d_causal_7) {
     delete results;
 }
 
+//////////////////////////////////////////////////////////////////////
+TEST_F(ConvolutionTests1, conv1d_causal_8) {
+
+    int bS=2, iW=8, iC=3, oC=4, kW=2, sW=1, pW=0, dW=2;
+    int oW = (iW-1)/sW + 1;
+
+    int paddingMode = 2;             // CAUSAL
+    int dataFormat  = 1;             // 1-NHWC, 0-NCHW
+
+    NDArray input('c', {bS, iW, iC}, nd4j::DataType::FLOAT32);
+    NDArray weights('c', {kW, iC, oC}, nd4j::DataType::FLOAT32);
+    NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 26.299999, 27.799999, 29.299999, 30.799999, 45.399998, 48.399998,
+        51.400002, 54.400005, 65.199997, 70.000000, 74.800003, 79.600006, 85.000000, 91.600006, 98.199997, 104.800003, 104.799995, 113.199997, 121.600006,
+        130.000000, 124.599998, 134.800003, 145.000000, 155.200012, 144.399994, 156.399994, 168.399994, 180.400009, 133.400009, 141.199997, 149.000000,
+        156.800003, 148.699997, 157.400009, 166.099991, 174.800003, 203.800003, 221.200012, 238.599991, 256.000000, 223.599991, 242.799988, 262.000000,
+        281.200012, 243.399994, 264.399994, 285.399994, 306.399994, 263.199982, 286.000000, 308.799988, 331.600006, 283.000000, 307.600006, 332.200012,
+        356.800018, 302.799988, 329.199982, 355.600006, 382.000000}, nd4j::DataType::FLOAT32);
+
+    input.linspace(1., 1.);
+    weights.linspace(0.1, 0.1);
+
+    nd4j::ops::conv1d op;
+    auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+    auto output = results->at(0);
+
+    ASSERT_EQ(Status::OK(), results->status());
+    ASSERT_TRUE(expOutput.isSameShape(output));
+    ASSERT_TRUE(expOutput.equalsTo(output));
+
+    delete results;
+}
+
 //////////////////////////////////////////////////////////////////////
 TEST_F(ConvolutionTests1, conv1d_causal_bp_1) {
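
To see where the expected values in conv1d_causal_8 above come from: a causal dilated conv1d left-pads the input with (kW - 1) * dW zeros, so output step t reads only steps t, t - dW, ... A naive single-batch reference in NWC layout (written here for illustration, not part of the library):

    #include <vector>

    // out[t][oc] = sum_{k,ic} in[t - (kW-1-k)*dW][ic] * w[k][ic][oc],
    // with negative input steps treated as zeros (implicit left padding).
    void causalConv1dRef(const std::vector<float>& in, int iW, int iC,
                         const std::vector<float>& w,  int kW, int dW, int oC,
                         std::vector<float>& out) {
        out.assign(iW * oC, 0.f);                       // sW == 1, so oW == iW
        for (int t = 0; t < iW; ++t)
            for (int k = 0; k < kW; ++k) {
                const int src = t - (kW - 1 - k) * dW;  // causal: look back only
                if (src < 0) continue;
                for (int ic = 0; ic < iC; ++ic)
                    for (int oc = 0; oc < oC; ++oc)
                        out[t * oC + oc] += in[src * iC + ic] * w[(k * iC + ic) * oC + oc];
            }
    }

With the test's linspace input (1, 2, 3, ...) and linspace weights (0.1, 0.2, ...), the first output entry is 1*1.3 + 2*1.7 + 3*2.1 = 11.0, matching expOutput; the k = 0 tap only starts contributing once t reaches dW = 2.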


@@ -1781,7 +1781,28 @@ TEST_F(DeclarableOpsTests3, betainc_test11) {
     NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32);
     NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32);
 
-    NDArray expected('c', {4}, {0.912156, 0.634443, 0.898314, 0.624544}, nd4j::DataType::FLOAT32);
+    NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::betainc op;
+    auto results = op.execute({&a, &b, &x}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *output = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+
+    delete results;
+}
+
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests3, betainc_test12) {
+
+    NDArray a('c', {4}, {8.0091f, 8.2108f, 7.5194f, 3.0780f}, nd4j::DataType::FLOAT32);
+    NDArray b('c', {4}, {7.9456f, 9.3527f, 9.8610f, 5.3541f}, nd4j::DataType::FLOAT32);
+    NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32);
+
+    NDArray expected('c', {4}, {0.9999995, 0.8594694, 0.999988, 0.49124345}, nd4j::DataType::FLOAT32);
 
     nd4j::ops::betainc op;
     auto results = op.execute({&a, &b, &x}, {}, {});