diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/betaInc.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/betaInc.cpp index 9d0a935a9..1b09bbf77 100644 --- a/libnd4j/include/ops/declarable/generic/parity_ops/betaInc.cpp +++ b/libnd4j/include/ops/declarable/generic/parity_ops/betaInc.cpp @@ -38,6 +38,12 @@ CONFIGURABLE_OP_IMPL(betainc, 3, 1, false, 0, 0) { auto b = INPUT_VARIABLE(1); auto x = INPUT_VARIABLE(2); + // just skip op if input is empty + if (x->isEmpty()) { + *x = DataTypeUtils::nanOrZero(); + return Status::OK(); + } + auto output = OUTPUT_VARIABLE(0); REQUIRE_TRUE(a->isSameShape(b) && a->isSameShape(x), 0, "CONFIGURABLE_OP betainc: all three input arrays must have the same shapes, bit got a=%s, b=%s and x=%s instead !", ShapeUtils::shapeAsString(a).c_str(), ShapeUtils::shapeAsString(b).c_str(), ShapeUtils::shapeAsString(x).c_str()); diff --git a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp index 88186b62a..83cc966ba 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/betaInc.cpp @@ -31,61 +31,56 @@ namespace helpers { /////////////////////////////////////////////////////////////////// // modified Lentz’s algorithm for continued fractions, // reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions” + template static T continuedFraction(const T a, const T b, const T x) { const T min = DataTypeUtils::min() / DataTypeUtils::eps(); const T aPlusb = a + b; - T val, delta, aPlus2i; + T val, aPlus2i; - // first iteration - T c = 1; - T d = static_cast(1) - aPlusb * x / (a + static_cast(1)); - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - T f = d; + T t2 = 1; + T t1 = static_cast(1) - aPlusb * x / (a + static_cast(1)); + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + T result = t1; - for(uint i = 1; i <= maxIter; i += 2) { + for(uint i = 1; i <= maxIter; ++i) { aPlus2i = a + static_cast(2*i); - - /***** even part *****/ val = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); - // d - d = static_cast(1) + val * d; - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - // c - c = static_cast(1) + val / c; - if(math::nd4j_abs(c) < min) - c = min; - // f - f *= c * d; - - - /***** odd part *****/ + // t1 + t1 = static_cast(1) + val * t1; + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + val / t2; + if(math::nd4j_abs(t2) < min) + t2 = min; + // result + result *= t2 * t1; val = -(a + i) * (aPlusb + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); - // d - d = static_cast(1) + val * d; - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - // c - c = static_cast(1) + val / c; - if(math::nd4j_abs(c) < min) - c = min; - // f - delta = c * d; - f *= delta; + // t1 + t1 = static_cast(1) + val * t1; + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + val / t2; + if(math::nd4j_abs(t2) < min) + t2 = min; + // result + val = t2 * t1; + result *= val; // condition to stop loop - if(math::nd4j_abs(delta - static_cast(1)) <= DataTypeUtils::eps()) - return f; + if(math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) + return result; } - return std::numeric_limits::infinity(); // no convergence, more iterations is required + return DataTypeUtils::infOrMax(); // no convergence, more iterations is required, return infinity } /////////////////////////////////////////////////////////////////// @@ -110,10 +105,9 @@ static T betaIncCore(T a, T b, T x) { const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); if (x <= (a + static_cast(1)) / (a + b + static_cast(2))) - return front * continuedFraction(a, b, x) / a; - else // symmetry relation - return static_cast(1) - front * continuedFraction(b, a, static_cast(1) - x) / b; - + return front * continuedFraction(a, b, x) / a; + else // symmetry relation + return static_cast(1) - front * continuedFraction(b, a, static_cast(1) - x) / b; } /////////////////////////////////////////////////////////////////// diff --git a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu index e7541a005..267ae21c2 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/betaInc.cu @@ -1,5 +1,5 @@ /******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (t2) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at @@ -39,54 +39,50 @@ __device__ T continuedFractionCuda(const T a, const T b, const T x) { const T min = DataTypeUtils::min() / DataTypeUtils::eps(); const T aPlusb = a + b; - T val, delta, aPlus2i; + T val, aPlus2i; - // first iteration - T c = 1; - T d = static_cast(1) - aPlusb * x / (a + static_cast(1)); - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - T f = d; + T t2 = coeffs[1]; + T t1 = coeffs[0]; + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + T result = t1; - for(uint i = 1; i <= maxIter; i += 2) { + for(uint i = 1; i <= maxIter; ++i) { - aPlus2i = a + static_cast(2*i); + const uint i2 = 2*i; + aPlus2i = a + static_cast(i2); - /***** even part *****/ - // d - d = static_cast(1) + coeffs[i - 1] * d; - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - // c - c = static_cast(1) + coeffs[i - 1] / c; - if(math::nd4j_abs(c) < min) - c = min; - // f - f *= c * d; - - - /***** odd part *****/ - // d - d = static_cast(1) + coeffs[i] * d; - if(math::nd4j_abs(d) < min) - d = min; - d = static_cast(1) / d; - // c - c = static_cast(1) + coeffs[i] / c; - if(math::nd4j_abs(c) < min) - c = min; - // f - delta = c * d; - f *= delta; + // t1 + t1 = static_cast(1) + coeffs[i2] * t1; + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + coeffs[i2] / t2; + if(math::nd4j_abs(t2) < min) + t2 = min; + // result + result *= t2 * t1; + // t1 + t1 = static_cast(1) + coeffs[i2 + 1] * t1; + if(math::nd4j_abs(t1) < min) + t1 = min; + t1 = static_cast(1) / t1; + // t2 + t2 = static_cast(1) + coeffs[i2 + 1] / t2; + if(math::nd4j_abs(t2) < min) + t2 = min; + // result + val = t2 * t1; + result *= val; // condition to stop loop - if(math::nd4j_abs(delta - static_cast(1)) <= DataTypeUtils::eps()) - return f; + if(math::nd4j_abs(val - static_cast(1)) <= DataTypeUtils::eps()) + return result; } - return 1.f / 0.f; // no convergence, more iterations is required + return DataTypeUtils::infOrMax(); // no convergence, more iterations is required, return infinity } /////////////////////////////////////////////////////////////////// @@ -112,7 +108,14 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, b = *(reinterpret_cast(vb) + shape::getIndexOffset(j, bShapeInfo)); x = *(reinterpret_cast(vx) + shape::getIndexOffset(j, xShapeInfo)); - symmCond = x <= (a + static_cast(1)) / (a + b + static_cast(2)); + symmCond = x > (a + static_cast(1)) / (a + b + static_cast(2)); + + if(symmCond) { // swap a and b, x = 1 - x + T temp = a; + a = b; + b = temp; + x = static_cast(1) - x; + } } __syncthreads(); @@ -124,23 +127,17 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, } if (x == static_cast(0) || x == static_cast(1)) { - z = x; + z = symmCond ? static_cast(1) - x : x; return; } - if(threadIdx.x % 2 == 0) { /***** even part *****/ - const int m = threadIdx.x + 1; - if(symmCond) - sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast(1)) * (a + 2 * m)); - else - sharedMem[threadIdx.x] = m * (a - m) * (1.f-x) / ((b + 2 * m - static_cast(1)) * (b + 2 * m)); - } - else { /***** odd part *****/ - const int m = threadIdx.x; - if(symmCond) - sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast(1)) * (a + 2 * m)); - else - sharedMem[threadIdx.x] = -(b + m) * (a + b + m) * (1.f-x) / ((b + 2 * m + static_cast(1)) * (b + 2 * m)); + // calculate two coefficients per thread + if(threadIdx.x != 0) { + + const int i = threadIdx.x; + const T aPlus2i = a + 2*i; + sharedMem[2*i] = i * (b - i) * x / ((aPlus2i - static_cast(1)) * aPlus2i); + sharedMem[2*i + 1] = -(a + i) * (a + b + i) * x / ((aPlus2i + static_cast(1)) * aPlus2i); } __syncthreads(); @@ -150,10 +147,13 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo, const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b); const T front = math::nd4j_exp(math::nd4j_log(x) * a + math::nd4j_log(1.f - x) * b - gammaPart); - if (symmCond) - z = front * continuedFractionCuda(a, b, x) / a; - else // symmetry relation - z = static_cast(1) - front * continuedFractionCuda(b, a, static_cast(1) - x) / b; + sharedMem[0] = static_cast(1) - (a + b) * x / (a + static_cast(1)); + sharedMem[1] = static_cast(1); + + z = front * continuedFractionCuda(a, b, x) / a; + + if(symmCond) // symmetry relation + z = static_cast(1) - z; } } @@ -174,7 +174,7 @@ void betaInc(nd4j::LaunchContext* context, const NDArray& a, const NDArray& b, c const int threadsPerBlock = maxIter; const int blocksPerGrid = output.lengthOf(); - const int sharedMem = output.sizeOfT() * threadsPerBlock + 128; + const int sharedMem = 2 * output.sizeOfT() * threadsPerBlock + 128; const auto xType = x.dataType(); diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp index 07cf60ef9..96bbffcf8 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.cpp @@ -141,7 +141,7 @@ namespace nd4j { void getMKLDNNMemoryDescConv2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW, + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, @@ -154,9 +154,11 @@ namespace nd4j { dnnl::memory::dims conv_bias_tz = { oC }; dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW }; + const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d + conv_strides = { sH, sW }; conv_padding = { pH, pW }; - conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW }; + conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame }; conv_dilation = { dH-1, dW-1}; auto type = dnnl::memory::data_type::f32; @@ -220,7 +222,7 @@ namespace nd4j { } void getMKLDNNMemoryDescConv3d( - int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW, + int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW, int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h index 1f9b9e010..6274a645f 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h +++ b/libnd4j/include/ops/declarable/platform/mkldnn/mkldnnUtils.h @@ -86,7 +86,7 @@ namespace nd4j{ * Utility methods for MKLDNN */ void getMKLDNNMemoryDescConv2d( - int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW, + int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW, int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src, const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst, dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md, diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp index 99092b37d..99cc98af9 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests1.cpp @@ -1040,6 +1040,39 @@ TEST_F(ConvolutionTests1, conv1d_causal_7) { delete results; } +////////////////////////////////////////////////////////////////////// +TEST_F(ConvolutionTests1, conv1d_causal_8) { + + int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=2; + int oW = (iW-1)/sW + 1; + int paddingMode = 2; // CAUSAL + int dataFormat = 1; // 1-NHWC, 0-NCHW + + NDArray input('c', {bS, iW, iC}, nd4j::DataType::FLOAT32); + NDArray weights('c', {kW, iC, oC}, nd4j::DataType::FLOAT32); + + NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 26.299999, 27.799999, 29.299999, 30.799999, 45.399998, 48.399998, + 51.400002, 54.400005, 65.199997, 70.000000, 74.800003, 79.600006, 85.000000, 91.600006, 98.199997, 104.800003, 104.799995, 113.199997, 121.600006, + 130.000000, 124.599998, 134.800003, 145.000000, 155.200012, 144.399994, 156.399994, 168.399994, 180.400009, 133.400009, 141.199997, 149.000000, + 156.800003, 148.699997, 157.400009, 166.099991, 174.800003, 203.800003, 221.200012, 238.599991, 256.000000, 223.599991, 242.799988, 262.000000, + 281.200012, 243.399994, 264.399994, 285.399994, 306.399994, 263.199982, 286.000000, 308.799988, 331.600006, 283.000000, 307.600006, 332.200012, + 356.800018, 302.799988, 329.199982, 355.600006, 382.000000}, nd4j::DataType::FLOAT32); + + input.linspace(1., 1.); + weights.linspace(0.1, 0.1); + + nd4j::ops::conv1d op; + auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat}); + auto output = results->at(0); + + ASSERT_EQ(Status::OK(), results->status()); + + ASSERT_TRUE(expOutput.isSameShape(output)); + ASSERT_TRUE(expOutput.equalsTo(output)); + + delete results; +} + ////////////////////////////////////////////////////////////////////// TEST_F(ConvolutionTests1, conv1d_causal_bp_1) { diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp index a521be97b..6d224b323 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests3.cpp @@ -1781,7 +1781,28 @@ TEST_F(DeclarableOpsTests3, betainc_test11) { NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32); NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32); - NDArray expected('c', {4}, {0.912156, 0.634443, 0.898314, 0.624544}, nd4j::DataType::FLOAT32); + NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, nd4j::DataType::FLOAT32); + nd4j::ops::betainc op; + auto results = op.execute({&a, &b, &x}, {}, {}); + + ASSERT_EQ(ND4J_STATUS_OK, results->status()); + + auto *output = results->at(0); + + ASSERT_TRUE(expected.isSameShape(output)); + ASSERT_TRUE(expected.equalsTo(output)); + + delete results; +} + +/////////////////////////////////////////////////////////////////// +TEST_F(DeclarableOpsTests3, betainc_test12) { + + NDArray a('c', {4}, {8.0091f, 8.2108f, 7.5194f, 3.0780f}, nd4j::DataType::FLOAT32); + NDArray b('c', {4}, {7.9456f, 9.3527f, 9.8610f, 5.3541f}, nd4j::DataType::FLOAT32); + NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32); + + NDArray expected('c', {4}, {0.9999995 , 0.8594694 , 0.999988 , 0.49124345}, nd4j::DataType::FLOAT32); nd4j::ops::betainc op; auto results = op.execute({&a, &b, &x}, {}, {});