DNNL/MKLDNN dilated causal conv1d + betainc (#103)
* - add padding calculation in same mode in causal conv1d op for right mkl paddings Signed-off-by: Yurii <iuriish@yahoo.com> * - correct causal condition in mkldnnUtils.cpp Signed-off-by: Yurii <iuriish@yahoo.com> * - correct some code which caused additional round errors is betainc op Signed-off-by: Yurii <iuriish@yahoo.com> * - put float in place of template parameter in nan assign in betainc op Signed-off-by: Yurii <iuriish@yahoo.com>master
parent
cb18d3d996
commit
578a5abb68
|
@ -38,6 +38,12 @@ CONFIGURABLE_OP_IMPL(betainc, 3, 1, false, 0, 0) {
|
||||||
auto b = INPUT_VARIABLE(1);
|
auto b = INPUT_VARIABLE(1);
|
||||||
auto x = INPUT_VARIABLE(2);
|
auto x = INPUT_VARIABLE(2);
|
||||||
|
|
||||||
|
// just skip op if input is empty
|
||||||
|
if (x->isEmpty()) {
|
||||||
|
*x = DataTypeUtils::nanOrZero<float>();
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
REQUIRE_TRUE(a->isSameShape(b) && a->isSameShape(x), 0, "CONFIGURABLE_OP betainc: all three input arrays must have the same shapes, bit got a=%s, b=%s and x=%s instead !", ShapeUtils::shapeAsString(a).c_str(), ShapeUtils::shapeAsString(b).c_str(), ShapeUtils::shapeAsString(x).c_str());
|
REQUIRE_TRUE(a->isSameShape(b) && a->isSameShape(x), 0, "CONFIGURABLE_OP betainc: all three input arrays must have the same shapes, bit got a=%s, b=%s and x=%s instead !", ShapeUtils::shapeAsString(a).c_str(), ShapeUtils::shapeAsString(b).c_str(), ShapeUtils::shapeAsString(x).c_str());
|
||||||
|
|
|
@ -31,61 +31,56 @@ namespace helpers {
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
// modified Lentz’s algorithm for continued fractions,
|
// modified Lentz’s algorithm for continued fractions,
|
||||||
// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions”
|
// reference: Lentz, W.J. 1976, “Generating Bessel Functions in Mie Scattering Calculations Using Continued Fractions”
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static T continuedFraction(const T a, const T b, const T x) {
|
static T continuedFraction(const T a, const T b, const T x) {
|
||||||
|
|
||||||
const T min = DataTypeUtils::min<T>() / DataTypeUtils::eps<T>();
|
const T min = DataTypeUtils::min<T>() / DataTypeUtils::eps<T>();
|
||||||
const T aPlusb = a + b;
|
const T aPlusb = a + b;
|
||||||
T val, delta, aPlus2i;
|
T val, aPlus2i;
|
||||||
|
|
||||||
// first iteration
|
T t2 = 1;
|
||||||
T c = 1;
|
T t1 = static_cast<T>(1) - aPlusb * x / (a + static_cast<T>(1));
|
||||||
T d = static_cast<T>(1) - aPlusb * x / (a + static_cast<T>(1));
|
if(math::nd4j_abs<T>(t1) < min)
|
||||||
if(math::nd4j_abs<T>(d) < min)
|
t1 = min;
|
||||||
d = min;
|
t1 = static_cast<T>(1) / t1;
|
||||||
d = static_cast<T>(1) / d;
|
T result = t1;
|
||||||
T f = d;
|
|
||||||
|
|
||||||
for(uint i = 1; i <= maxIter; i += 2) {
|
for(uint i = 1; i <= maxIter; ++i) {
|
||||||
|
|
||||||
aPlus2i = a + static_cast<T>(2*i);
|
aPlus2i = a + static_cast<T>(2*i);
|
||||||
|
|
||||||
/***** even part *****/
|
|
||||||
val = i * (b - i) * x / ((aPlus2i - static_cast<T>(1)) * aPlus2i);
|
val = i * (b - i) * x / ((aPlus2i - static_cast<T>(1)) * aPlus2i);
|
||||||
// d
|
// t1
|
||||||
d = static_cast<T>(1) + val * d;
|
t1 = static_cast<T>(1) + val * t1;
|
||||||
if(math::nd4j_abs<T>(d) < min)
|
if(math::nd4j_abs<T>(t1) < min)
|
||||||
d = min;
|
t1 = min;
|
||||||
d = static_cast<T>(1) / d;
|
t1 = static_cast<T>(1) / t1;
|
||||||
// c
|
// t2
|
||||||
c = static_cast<T>(1) + val / c;
|
t2 = static_cast<T>(1) + val / t2;
|
||||||
if(math::nd4j_abs<T>(c) < min)
|
if(math::nd4j_abs<T>(t2) < min)
|
||||||
c = min;
|
t2 = min;
|
||||||
// f
|
// result
|
||||||
f *= c * d;
|
result *= t2 * t1;
|
||||||
|
|
||||||
|
|
||||||
/***** odd part *****/
|
|
||||||
val = -(a + i) * (aPlusb + i) * x / ((aPlus2i + static_cast<T>(1)) * aPlus2i);
|
val = -(a + i) * (aPlusb + i) * x / ((aPlus2i + static_cast<T>(1)) * aPlus2i);
|
||||||
// d
|
// t1
|
||||||
d = static_cast<T>(1) + val * d;
|
t1 = static_cast<T>(1) + val * t1;
|
||||||
if(math::nd4j_abs<T>(d) < min)
|
if(math::nd4j_abs<T>(t1) < min)
|
||||||
d = min;
|
t1 = min;
|
||||||
d = static_cast<T>(1) / d;
|
t1 = static_cast<T>(1) / t1;
|
||||||
// c
|
// t2
|
||||||
c = static_cast<T>(1) + val / c;
|
t2 = static_cast<T>(1) + val / t2;
|
||||||
if(math::nd4j_abs<T>(c) < min)
|
if(math::nd4j_abs<T>(t2) < min)
|
||||||
c = min;
|
t2 = min;
|
||||||
// f
|
// result
|
||||||
delta = c * d;
|
val = t2 * t1;
|
||||||
f *= delta;
|
result *= val;
|
||||||
|
|
||||||
// condition to stop loop
|
// condition to stop loop
|
||||||
if(math::nd4j_abs<T>(delta - static_cast<T>(1)) <= DataTypeUtils::eps<T>())
|
if(math::nd4j_abs<T>(val - static_cast<T>(1)) <= DataTypeUtils::eps<T>())
|
||||||
return f;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
return std::numeric_limits<float>::infinity(); // no convergence, more iterations is required
|
return DataTypeUtils::infOrMax<T>(); // no convergence, more iterations is required, return infinity
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
|
@ -110,10 +105,9 @@ static T betaIncCore(T a, T b, T x) {
|
||||||
const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
|
const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
|
||||||
|
|
||||||
if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
|
if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
|
||||||
return front * continuedFraction(a, b, x) / a;
|
return front * continuedFraction<T>(a, b, x) / a;
|
||||||
else // symmetry relation
|
else // symmetry relation
|
||||||
return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x) / b;
|
return static_cast<T>(1) - front * continuedFraction<T>(b, a, static_cast<T>(1) - x) / b;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (t2) 2015-2018 Skymind, Inc.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -39,54 +39,50 @@ __device__ T continuedFractionCuda(const T a, const T b, const T x) {
|
||||||
|
|
||||||
const T min = DataTypeUtils::min<T>() / DataTypeUtils::eps<T>();
|
const T min = DataTypeUtils::min<T>() / DataTypeUtils::eps<T>();
|
||||||
const T aPlusb = a + b;
|
const T aPlusb = a + b;
|
||||||
T val, delta, aPlus2i;
|
T val, aPlus2i;
|
||||||
|
|
||||||
// first iteration
|
T t2 = coeffs[1];
|
||||||
T c = 1;
|
T t1 = coeffs[0];
|
||||||
T d = static_cast<T>(1) - aPlusb * x / (a + static_cast<T>(1));
|
if(math::nd4j_abs<T>(t1) < min)
|
||||||
if(math::nd4j_abs<T>(d) < min)
|
t1 = min;
|
||||||
d = min;
|
t1 = static_cast<T>(1) / t1;
|
||||||
d = static_cast<T>(1) / d;
|
T result = t1;
|
||||||
T f = d;
|
|
||||||
|
|
||||||
for(uint i = 1; i <= maxIter; i += 2) {
|
for(uint i = 1; i <= maxIter; ++i) {
|
||||||
|
|
||||||
aPlus2i = a + static_cast<T>(2*i);
|
const uint i2 = 2*i;
|
||||||
|
aPlus2i = a + static_cast<T>(i2);
|
||||||
|
|
||||||
/***** even part *****/
|
// t1
|
||||||
// d
|
t1 = static_cast<T>(1) + coeffs[i2] * t1;
|
||||||
d = static_cast<T>(1) + coeffs[i - 1] * d;
|
if(math::nd4j_abs<T>(t1) < min)
|
||||||
if(math::nd4j_abs<T>(d) < min)
|
t1 = min;
|
||||||
d = min;
|
t1 = static_cast<T>(1) / t1;
|
||||||
d = static_cast<T>(1) / d;
|
// t2
|
||||||
// c
|
t2 = static_cast<T>(1) + coeffs[i2] / t2;
|
||||||
c = static_cast<T>(1) + coeffs[i - 1] / c;
|
if(math::nd4j_abs<T>(t2) < min)
|
||||||
if(math::nd4j_abs<T>(c) < min)
|
t2 = min;
|
||||||
c = min;
|
// result
|
||||||
// f
|
result *= t2 * t1;
|
||||||
f *= c * d;
|
// t1
|
||||||
|
t1 = static_cast<T>(1) + coeffs[i2 + 1] * t1;
|
||||||
|
if(math::nd4j_abs<T>(t1) < min)
|
||||||
/***** odd part *****/
|
t1 = min;
|
||||||
// d
|
t1 = static_cast<T>(1) / t1;
|
||||||
d = static_cast<T>(1) + coeffs[i] * d;
|
// t2
|
||||||
if(math::nd4j_abs<T>(d) < min)
|
t2 = static_cast<T>(1) + coeffs[i2 + 1] / t2;
|
||||||
d = min;
|
if(math::nd4j_abs<T>(t2) < min)
|
||||||
d = static_cast<T>(1) / d;
|
t2 = min;
|
||||||
// c
|
// result
|
||||||
c = static_cast<T>(1) + coeffs[i] / c;
|
val = t2 * t1;
|
||||||
if(math::nd4j_abs<T>(c) < min)
|
result *= val;
|
||||||
c = min;
|
|
||||||
// f
|
|
||||||
delta = c * d;
|
|
||||||
f *= delta;
|
|
||||||
|
|
||||||
// condition to stop loop
|
// condition to stop loop
|
||||||
if(math::nd4j_abs<T>(delta - static_cast<T>(1)) <= DataTypeUtils::eps<T>())
|
if(math::nd4j_abs<T>(val - static_cast<T>(1)) <= DataTypeUtils::eps<T>())
|
||||||
return f;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
return 1.f / 0.f; // no convergence, more iterations is required
|
return DataTypeUtils::infOrMax<T>(); // no convergence, more iterations is required, return infinity
|
||||||
}
|
}
|
||||||
|
|
||||||
///////////////////////////////////////////////////////////////////
|
///////////////////////////////////////////////////////////////////
|
||||||
|
@ -112,7 +108,14 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
|
||||||
b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
|
b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
|
||||||
x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
|
x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
|
||||||
|
|
||||||
symmCond = x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2));
|
symmCond = x > (a + static_cast<T>(1)) / (a + b + static_cast<T>(2));
|
||||||
|
|
||||||
|
if(symmCond) { // swap a and b, x = 1 - x
|
||||||
|
T temp = a;
|
||||||
|
a = b;
|
||||||
|
b = temp;
|
||||||
|
x = static_cast<T>(1) - x;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
@ -124,23 +127,17 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (x == static_cast<T>(0) || x == static_cast<T>(1)) {
|
if (x == static_cast<T>(0) || x == static_cast<T>(1)) {
|
||||||
z = x;
|
z = symmCond ? static_cast<T>(1) - x : x;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if(threadIdx.x % 2 == 0) { /***** even part *****/
|
// calculate two coefficients per thread
|
||||||
const int m = threadIdx.x + 1;
|
if(threadIdx.x != 0) {
|
||||||
if(symmCond)
|
|
||||||
sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast<T>(1)) * (a + 2 * m));
|
const int i = threadIdx.x;
|
||||||
else
|
const T aPlus2i = a + 2*i;
|
||||||
sharedMem[threadIdx.x] = m * (a - m) * (1.f-x) / ((b + 2 * m - static_cast<T>(1)) * (b + 2 * m));
|
sharedMem[2*i] = i * (b - i) * x / ((aPlus2i - static_cast<T>(1)) * aPlus2i);
|
||||||
}
|
sharedMem[2*i + 1] = -(a + i) * (a + b + i) * x / ((aPlus2i + static_cast<T>(1)) * aPlus2i);
|
||||||
else { /***** odd part *****/
|
|
||||||
const int m = threadIdx.x;
|
|
||||||
if(symmCond)
|
|
||||||
sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast<T>(1)) * (a + 2 * m));
|
|
||||||
else
|
|
||||||
sharedMem[threadIdx.x] = -(b + m) * (a + b + m) * (1.f-x) / ((b + 2 * m + static_cast<T>(1)) * (b + 2 * m));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
@ -150,10 +147,13 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
|
||||||
const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
|
const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
|
||||||
const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
|
const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
|
||||||
|
|
||||||
if (symmCond)
|
sharedMem[0] = static_cast<T>(1) - (a + b) * x / (a + static_cast<T>(1));
|
||||||
z = front * continuedFractionCuda(a, b, x) / a;
|
sharedMem[1] = static_cast<T>(1);
|
||||||
else // symmetry relation
|
|
||||||
z = static_cast<T>(1) - front * continuedFractionCuda(b, a, static_cast<T>(1) - x) / b;
|
z = front * continuedFractionCuda(a, b, x) / a;
|
||||||
|
|
||||||
|
if(symmCond) // symmetry relation
|
||||||
|
z = static_cast<T>(1) - z;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -174,7 +174,7 @@ void betaInc(nd4j::LaunchContext* context, const NDArray& a, const NDArray& b, c
|
||||||
|
|
||||||
const int threadsPerBlock = maxIter;
|
const int threadsPerBlock = maxIter;
|
||||||
const int blocksPerGrid = output.lengthOf();
|
const int blocksPerGrid = output.lengthOf();
|
||||||
const int sharedMem = output.sizeOfT() * threadsPerBlock + 128;
|
const int sharedMem = 2 * output.sizeOfT() * threadsPerBlock + 128;
|
||||||
|
|
||||||
const auto xType = x.dataType();
|
const auto xType = x.dataType();
|
||||||
|
|
||||||
|
|
|
@ -141,7 +141,7 @@ namespace nd4j {
|
||||||
|
|
||||||
|
|
||||||
void getMKLDNNMemoryDescConv2d(
|
void getMKLDNNMemoryDescConv2d(
|
||||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
|
||||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||||
|
@ -154,9 +154,11 @@ namespace nd4j {
|
||||||
dnnl::memory::dims conv_bias_tz = { oC };
|
dnnl::memory::dims conv_bias_tz = { oC };
|
||||||
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
dnnl::memory::dims conv_dst_tz = { bS, oC, oH, oW };
|
||||||
|
|
||||||
|
const int pWSame = (paddingMode == 2 && dW > 1) ? ((oW - 1) * sW + (kW - 1) * dW + 1 - iW) / 2 : pW; // dH == 1 for causal mode in conv1d
|
||||||
|
|
||||||
conv_strides = { sH, sW };
|
conv_strides = { sH, sW };
|
||||||
conv_padding = { pH, pW };
|
conv_padding = { pH, pW };
|
||||||
conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pW };
|
conv_padding_r = { (oH - 1) * sH - iH + kH - pH, (oW - 1) * sW - iW + kW - pWSame };
|
||||||
conv_dilation = { dH-1, dW-1};
|
conv_dilation = { dH-1, dW-1};
|
||||||
|
|
||||||
auto type = dnnl::memory::data_type::f32;
|
auto type = dnnl::memory::data_type::f32;
|
||||||
|
@ -220,7 +222,7 @@ namespace nd4j {
|
||||||
}
|
}
|
||||||
|
|
||||||
void getMKLDNNMemoryDescConv3d(
|
void getMKLDNNMemoryDescConv3d(
|
||||||
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool isSameMode, bool isNCDHW,
|
int kD, int kH, int kW, int sD, int sH, int sW, int pD, int pH, int pW, int dD, int dH, int dW, bool paddingMode, bool isNCDHW,
|
||||||
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
int bS, int iC, int iD, int iH, int iW, int oC, int oD, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||||
|
|
|
@ -86,7 +86,7 @@ namespace nd4j{
|
||||||
* Utility methods for MKLDNN
|
* Utility methods for MKLDNN
|
||||||
*/
|
*/
|
||||||
void getMKLDNNMemoryDescConv2d(
|
void getMKLDNNMemoryDescConv2d(
|
||||||
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, bool isSameMode, bool isNCHW,
|
int kH, int kW, int sH, int sW, int pH, int pW, int dH, int dW, const int paddingMode, bool isNCHW,
|
||||||
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
int bS, int iC, int iH, int iW, int oC, int oH, int oW, const NDArray* src, const NDArray* diff_src,
|
||||||
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
const NDArray* weights, const NDArray* diff_weights, const NDArray* bias, const NDArray* dst,
|
||||||
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
dnnl::memory::desc* conv_src_md, dnnl::memory::desc* conv_diff_src_md, dnnl::memory::desc* conv_weights_md,
|
||||||
|
|
|
@ -1040,6 +1040,39 @@ TEST_F(ConvolutionTests1, conv1d_causal_7) {
|
||||||
delete results;
|
delete results;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(ConvolutionTests1, conv1d_causal_8) {
|
||||||
|
|
||||||
|
int bS=2, iW=8, iC=3,oC=4, kW=2, sW=1, pW=0, dW=2;
|
||||||
|
int oW = (iW-1)/sW + 1;
|
||||||
|
int paddingMode = 2; // CAUSAL
|
||||||
|
int dataFormat = 1; // 1-NHWC, 0-NCHW
|
||||||
|
|
||||||
|
NDArray input('c', {bS, iW, iC}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray weights('c', {kW, iC, oC}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 26.299999, 27.799999, 29.299999, 30.799999, 45.399998, 48.399998,
|
||||||
|
51.400002, 54.400005, 65.199997, 70.000000, 74.800003, 79.600006, 85.000000, 91.600006, 98.199997, 104.800003, 104.799995, 113.199997, 121.600006,
|
||||||
|
130.000000, 124.599998, 134.800003, 145.000000, 155.200012, 144.399994, 156.399994, 168.399994, 180.400009, 133.400009, 141.199997, 149.000000,
|
||||||
|
156.800003, 148.699997, 157.400009, 166.099991, 174.800003, 203.800003, 221.200012, 238.599991, 256.000000, 223.599991, 242.799988, 262.000000,
|
||||||
|
281.200012, 243.399994, 264.399994, 285.399994, 306.399994, 263.199982, 286.000000, 308.799988, 331.600006, 283.000000, 307.600006, 332.200012,
|
||||||
|
356.800018, 302.799988, 329.199982, 355.600006, 382.000000}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
input.linspace(1., 1.);
|
||||||
|
weights.linspace(0.1, 0.1);
|
||||||
|
|
||||||
|
nd4j::ops::conv1d op;
|
||||||
|
auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
|
||||||
|
auto output = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(Status::OK(), results->status());
|
||||||
|
|
||||||
|
ASSERT_TRUE(expOutput.isSameShape(output));
|
||||||
|
ASSERT_TRUE(expOutput.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(ConvolutionTests1, conv1d_causal_bp_1) {
|
TEST_F(ConvolutionTests1, conv1d_causal_bp_1) {
|
||||||
|
|
||||||
|
|
|
@ -1781,7 +1781,28 @@ TEST_F(DeclarableOpsTests3, betainc_test11) {
|
||||||
NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32);
|
NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32);
|
||||||
NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32);
|
NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
NDArray expected('c', {4}, {0.912156, 0.634443, 0.898314, 0.624544}, nd4j::DataType::FLOAT32);
|
NDArray expected('c', {4}, {0.912156, 0.634460, 0.898314, 0.624538}, nd4j::DataType::FLOAT32);
|
||||||
|
nd4j::ops::betainc op;
|
||||||
|
auto results = op.execute({&a, &b, &x}, {}, {});
|
||||||
|
|
||||||
|
ASSERT_EQ(ND4J_STATUS_OK, results->status());
|
||||||
|
|
||||||
|
auto *output = results->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(expected.isSameShape(output));
|
||||||
|
ASSERT_TRUE(expected.equalsTo(output));
|
||||||
|
|
||||||
|
delete results;
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests3, betainc_test12) {
|
||||||
|
|
||||||
|
NDArray a('c', {4}, {8.0091f, 8.2108f, 7.5194f, 3.0780f}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray b('c', {4}, {7.9456f, 9.3527f, 9.8610f, 5.3541f}, nd4j::DataType::FLOAT32);
|
||||||
|
NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
|
NDArray expected('c', {4}, {0.9999995 , 0.8594694 , 0.999988 , 0.49124345}, nd4j::DataType::FLOAT32);
|
||||||
|
|
||||||
nd4j::ops::betainc op;
|
nd4j::ops::betainc op;
|
||||||
auto results = op.execute({&a, &b, &x}, {}, {});
|
auto results = op.execute({&a, &b, &x}, {}, {});
|
||||||
|
|
Loading…
Reference in New Issue