Shyrma adjust (#98)

* - add possibility of passing a scalar-array as an input parameter for the scale factor in the adjust hue/contrast/saturation ops
- correct typo in the function which calculates the regularized incomplete beta integral

Signed-off-by: Yurii <iuriish@yahoo.com>

* - fix bug in betainc cuda kernel

Signed-off-by: Yurii <iuriish@yahoo.com>

* - start working on implementation of digamma function

Signed-off-by: Yurii <iuriish@yahoo.com>

* - further work on digamma function (cpu)

Signed-off-by: Yurii <iuriish@yahoo.com>

* - testing and fixing bugs in digamma op

Signed-off-by: Yurii <iuriish@yahoo.com>

* - make correction in cuda kernel for polyGamma

Signed-off-by: Yurii <iuriish@yahoo.com>

* - remove unnecessary stuff from betaInc cuda kernel

Signed-off-by: Yurii <iuriish@yahoo.com>

* - resolve conflicts in DeclarableOpsTests3.cpp after master branch has been merged

Signed-off-by: Yurii <iuriish@yahoo.com>

* - restore id number of Not operation in legacy_ops.h

Signed-off-by: Yurii <iuriish@yahoo.com>

* - correct padding calculation in mkl dnn causal conv1d

Signed-off-by: Yurii <iuriish@yahoo.com>

* restore empty check in adjust_contrast_v2

Signed-off-by: raver119 <raver119@gmail.com>
Branch: master
Yurii Shyrma, 2019-12-03 08:40:45 +02:00, committed by raver119
parent 1e9ff114aa, commit 1f5e15b541
20 changed files with 750 additions and 326 deletions

View File

@@ -112,7 +112,8 @@
 (4, IsInfOrNan), \
 (5, MatchConditionBool), \
 (6, IsPositive) , \
-(7, Not)
+(7, Not), \
+(8, IsNegative)
 #define TRANSFORM_STRICT_OPS \
@@ -279,7 +280,8 @@
 (3, IsInfOrNan), \
 (4, IsNan), \
 (5, IsInf), \
-(6, IsPositive)
+(6, IsPositive), \
+(7, IsNegative)
 #define REDUCE_SAME_OPS \
 (0, Sum), \

View File

@@ -27,7 +27,8 @@
 namespace nd4j {
 namespace ops {
-CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) {
+////////////////////////////////////////////////////////////////////
+CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) {
     auto input = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
@@ -37,23 +38,31 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) {
         return Status::OK();
     REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required");
-    const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0);
     REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
     REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
-    // compute mean before
+    NDArray* factor = nullptr;
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
     // fill up axes vector first
     std::vector<int> axes(input->rankOf() - 1);
     for (auto i = 0; i < axes.size(); ++i)
         axes[i] = i;
     // mean as reduction for last dimension set
     auto mean = input->reduceAlongDims(reduce::Mean, axes);
-    NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext());
-    factorT.p(0, factor);
     // this is contrast calculation
-    output->assign((*input - mean) * factorT + mean);
+    output->assign((*input - mean) * (*factor) + mean);
+    if(block.width() == 1)
+        delete factor;
     return Status::OK();
 }
@@ -64,22 +73,28 @@ DECLARE_TYPES(adjust_contrast) {
         ->setSameMode(true);
 }
-CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, -2, 0) {
+////////////////////////////////////////////////////////////////////
+CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) {
     auto input = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
-    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required");
-    const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0);
     // just skip op if input is empty
     if (input->isEmpty())
         return Status::OK();
     REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
     REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required");
+    NDArray* factor = nullptr;
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
     // compute mean before
     std::vector<int> axes(input->rankOf() - 1);
@@ -92,9 +107,12 @@ DECLARE_TYPES(adjust_contrast) {
     // result as (x - mean) * factor + mean
     auto temp = input->ulike();
     input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp);
-    temp.applyScalar(scalar::Multiply, factor);
+    temp.applyScalarArr(scalar::Multiply, factor);
     temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output);
+    if(block.width() == 1)
+        delete factor;
     return Status::OK();
 }

View File

@@ -24,13 +24,12 @@
 #include <ops/declarable/headers/parity_ops.h>
 #include <ops/declarable/helpers/adjust_hue.h>
-#include <NDArrayFactory.h>
 namespace nd4j {
 namespace ops {
-CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {
+CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) {
     auto input = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
@@ -41,15 +40,26 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {
     const int rank = input->rankOf();
     const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
-    const double delta = T_ARG(0);
+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !");
     REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank);
     REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
-    REQUIRE_TRUE(-1. <= delta && delta <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);
-    NDArray deltaScalarArr = NDArrayFactory::create<double>(delta, block.launchContext());
-    helpers::adjustHue(block.launchContext(), input, &deltaScalarArr, output, dimC);
+    NDArray* delta = nullptr;
+    if(block.width() > 1)
+        delta = INPUT_VARIABLE(1);
+    else {
+        delta = new NDArray(output->dataType(), block.launchContext());
+        delta->p(0, T_ARG(0));
+    }
+    REQUIRE_TRUE(-1. <= delta->e<double>(0) && delta->e<double>(0) <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);
+    helpers::adjustHue(block.launchContext(), input, delta, output, dimC);
+    if(block.width() == 1)
+        delete delta;
     return Status::OK();
 }

View File

@@ -28,7 +28,7 @@
 namespace nd4j {
 namespace ops {
-CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
+CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) {
     auto input = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
@@ -39,14 +39,24 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
     const int rank = input->rankOf();
     const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
-    const double factor = T_ARG(0);
     REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank);
     REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
-    NDArray factorScalarArr = NDArrayFactory::create<double>(factor, block.launchContext());
-    helpers::adjustSaturation(block.launchContext(), input, &factorScalarArr, output, dimC);
+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !");
+    NDArray* factor = nullptr;
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
+    helpers::adjustSaturation(block.launchContext(), input, factor, output, dimC);
+    if(block.width() == 1)
+        delete factor;
     return Status::OK();
 }

View File

@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -15,27 +16,35 @@
  ******************************************************************************/
 //
-// Created by Yurii Shyrma on 13.12.2017.
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
-#ifndef LIBND4J_POLYGAMMA_H
-#define LIBND4J_POLYGAMMA_H
-#include <ops/declarable/helpers/helpers.h>
-#include "NDArray.h"
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_digamma)
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/gammaMathFunc.h>
 namespace nd4j {
 namespace ops {
-namespace helpers {
-// calculate the polygamma function
-void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output);
+CONFIGURABLE_OP_IMPL(digamma, 1, 1, false, 0, 0) {
+    auto x = INPUT_VARIABLE(0);
+    auto z = OUTPUT_VARIABLE(0);
+    helpers::diGamma(block.launchContext(), *x, *z);
+    return Status::OK();
+}
+DECLARE_TYPES(digamma) {
+    getOpDescriptor()
+        ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS})
+        ->setSameMode(true);
+}
 }
 }
-}
-#endif //LIBND4J_POLYGAMMA_H
+#endif
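For reference, a minimal usage sketch of the new digamma op (not part of the diff), written in the same style as the tests added later in this commit; the expected values are the standard references psi(1) = -gamma, psi(2) = 1 - gamma and psi(0.5) = -gamma - 2*ln(2):

    // hedged sketch: follows the test-style calls used elsewhere in this PR
    NDArray x('c', {3}, {1., 2., 0.5}, nd4j::DataType::DOUBLE);
    NDArray expected('c', {3}, {-0.577216, 0.422784, -1.963510}, nd4j::DataType::DOUBLE);

    nd4j::ops::digamma op;
    auto results = op.execute({&x}, {}, {});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());
    ASSERT_TRUE(expected.equalsTo(results->at(0)));

    delete results;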

View File

@@ -15,14 +15,14 @@
 ******************************************************************************/
 //
-// @author Yurii Shyrma (iuriish@yahoo.com), created on 13.12.2017
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 #include <op_boilerplate.h>
 #if NOT_EXCLUDED(OP_polygamma)
 #include <ops/declarable/CustomOperations.h>
-#include <ops/declarable/helpers/polyGamma.h>
+#include <ops/declarable/helpers/gammaMathFunc.h>
 namespace nd4j {
 namespace ops {
@@ -37,11 +37,11 @@ CONFIGURABLE_OP_IMPL(polygamma, 2, 1, false, 0, 0) {
     Nd4jLong arrLen = n->lengthOf();
     // FIXME: this shit should be single op call, not a loop!
-    auto nPositive = n->reduceNumber(nd4j::reduce::IsPositive, nullptr);
+    auto nNegative = n->reduceNumber(nd4j::reduce::IsNegative, nullptr);
     auto xPositive = x->reduceNumber(nd4j::reduce::IsPositive, nullptr);
-    bool nPositiveFlag = nPositive.e<bool>(0);
-    bool xPositiveFlag = xPositive.e<bool>(0);
-    REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be > 0 !");
+    bool nPositiveFlag = !nNegative.e<bool>(0);   // require all n >= 0
+    bool xPositiveFlag = xPositive.e<bool>(0);    // require all x > 0
+    REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be >= 0 !");
     REQUIRE_TRUE(xPositiveFlag, 0, "POLYGAMMA op: all elements of x array must be > 0 !");
     helpers::polyGamma(block.launchContext(), *n, *x, *output);

View File

@@ -513,7 +513,6 @@ namespace nd4j {
     /**
      * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in
      * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x).
-     * Currently the case n = 0 is not supported.
      *
      * Input arrays:
      *    0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer)
@@ -528,6 +527,20 @@
     DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0);
     #endif
+    /**
+     * This op calculates digamma function psi(x) = derivative of log(Gamma(x))
+     *
+     * Input arrays:
+     *    0: x - abscissa points where to evaluate the digamma function, type float
+     *
+     * Output array:
+     *    0: values of digamma function at corresponding x, type float
+     *
+     */
+    #if NOT_EXCLUDED(OP_digamma)
+    DECLARE_CONFIGURABLE_OP(digamma, 1, 1, false, 0, 0);
+    #endif
     /**
      * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value.
      * Input arrays:
@@ -575,44 +588,47 @@
      * This operation adjusts image hue by delta
      * Input arrays:
      *    0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
+     *    1 - optional argument, input scalar-array containing delta
      *
      * T arguments:
-     *    0 - delta value
+     *    0 - optional argument, delta value
      *
      * Int arguments:
      *    0 - optional argument, corresponds to dimension with 3 channels
      */
     #if NOT_EXCLUDED(OP_adjust_hue)
-    DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 1, -2);
+    DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 0, 0);
     #endif
     /**
      * This operation adjusts image saturation by delta
      * Input arrays:
      *    0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
+     *    1 - optional argument, input scalar-array containing saturation factor
     *
      * T arguments:
-     *    0 - saturation factor
+     *    0 - optional argument, saturation factor
      *
      * Int arguments:
      *    0 - optional argument, corresponds to dimension with 3 channels
      */
     #if NOT_EXCLUDED(OP_adjust_saturation)
-    DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2);
+    DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 0, 0);
     #endif
     /**
      * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
      * Input arrays:
      *    0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
+     *    1 - optional argument, input scalar-array containing saturation contrast factor
      *
      * T arguments:
-     *    0 - contrast factor
+     *    0 - optional argument, contrast factor
      *
      */
     #if NOT_EXCLUDED(OP_adjust_contrast)
-    DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, -2, 0);
-    DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, -2, 0);
+    DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 0, 0);
+    DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 0, 0);
     #endif

View File

@@ -107,12 +107,12 @@ static T betaIncCore(T a, T b, T x) {
         return x;
     const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
-    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1 - x) * b - gammaPart) / a;
+    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
     if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
-        return front * continuedFraction(a, b, x);
+        return front * continuedFraction(a, b, x) / a;
     else // symmetry relation
-        return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x);
+        return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x) / b;
 }
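For context, a self-contained sketch of the same evaluation scheme (a Numerical-Recipes-style modified Lentz continued fraction, not the library helper) showing where the division belongs: by a in the direct branch and by b in the symmetry branch, which is exactly the typo fixed above:

#include <cmath>
#include <cstdio>

// continued fraction for the regularized incomplete beta function
static double betaContinuedFraction(double a, double b, double x) {
    const int maxIter = 200;
    const double eps = 1.e-12, tiny = 1.e-30;
    double c = 1.0, d = 1.0 - (a + b) * x / (a + 1.0);
    if (std::fabs(d) < tiny) d = tiny;
    d = 1.0 / d;
    double h = d;
    for (int m = 1; m <= maxIter; ++m) {
        const int m2 = 2 * m;
        double num = m * (b - m) * x / ((a + m2 - 1.0) * (a + m2));        // even term
        d = 1.0 + num * d; if (std::fabs(d) < tiny) d = tiny;
        c = 1.0 + num / c; if (std::fabs(c) < tiny) c = tiny;
        h *= (d = 1.0 / d) * c;
        num = -(a + m) * (a + b + m) * x / ((a + m2) * (a + m2 + 1.0));    // odd term
        d = 1.0 + num * d; if (std::fabs(d) < tiny) d = tiny;
        c = 1.0 + num / c; if (std::fabs(c) < tiny) c = tiny;
        const double del = (d = 1.0 / d) * c;
        h *= del;
        if (std::fabs(del - 1.0) < eps) break;
    }
    return h;
}

// regularized incomplete beta I_x(a, b)
static double betaInc(double a, double b, double x) {
    const double gammaPart = std::lgamma(a) + std::lgamma(b) - std::lgamma(a + b);
    const double front = std::exp(a * std::log(x) + b * std::log(1.0 - x) - gammaPart);
    if (x <= (a + 1.0) / (a + b + 2.0))
        return front * betaContinuedFraction(a, b, x) / a;                 // divide by a here
    else                                                                   // symmetry relation
        return 1.0 - front * betaContinuedFraction(b, a, 1.0 - x) / b;     // divide by b here
}

int main() {
    // I_0.5(2, 2) = 0.5 exactly, a quick sanity check
    std::printf("%.6f\n", betaInc(2.0, 2.0, 0.5));
    return 0;
}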

View File

@@ -0,0 +1,53 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include<ops/declarable/helpers/gammaMathFunc.h>
#include <execution/Threads.h>
namespace nd4j {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
// calculate digamma function for array elements
template <typename T>
static void diGamma_(const NDArray& x, NDArray& z) {
auto func = PRAGMA_THREADS_FOR {
for (auto i = start; i < stop; i += increment)
z.p(i, diGammaScalar<T>(x.e<T>(i)));
};
samediff::Threads::parallel_for(func, 0, x.lengthOf());
}
void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) {
BUILD_SINGLE_SELECTOR(x.dataType(), diGamma_, (x, z), FLOAT_TYPES);
}
BUILD_SINGLE_TEMPLATE(template void diGamma_, (const NDArray& x, NDArray& z), FLOAT_TYPES);
}
}
}

View File

@@ -18,7 +18,7 @@
 // Created by Yurii Shyrma on 12.12.2017
 //
-#include<ops/declarable/helpers/polyGamma.h>
+#include<ops/declarable/helpers/gammaMathFunc.h>
 #include<ops/declarable/helpers/zeta.h>
 #include <NDArrayFactory.h>
 #include <execution/Threads.h>
@@ -57,8 +57,6 @@ static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n,
 //    if (x <= (T)0.)
 //        throw("polyGamma function: x must be > 0 !");
-    // TODO case for n = 0 (digamma)
     int sign = (n + 1) % 2 ? -1 : 1;
     // T factorial = (T)std::tgamma(n + 1);
@@ -71,17 +69,18 @@
 template <typename T>
 static void polyGamma_(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {
-    NDArray& result = output;
-    int xLen = x.lengthOf();
     auto func = PRAGMA_THREADS_FOR {
-        for (auto i = start; i < stop; i += increment)
-            result.p(i, polyGammaScalar<T>(context, n.e<int>(i), x.e<T>(i)));
+        for (auto i = start; i < stop; i += increment) {
+            const T order = n.e<T>(i);
+            if(order != static_cast<int>(order))    // if order has fractional part then do not perform calculations and return NAN
+                output.p(i, std::numeric_limits<T>::quiet_NaN());
+            else if (order == 0)                    // polygamma function of zero order is digamma function
+                output.p(i, diGammaScalar<T>(x.e<T>(i)));
+            else
+                output.p(i, polyGammaScalar<T>(context, order, x.e<T>(i)));
+        }
     };
     samediff::Threads::parallel_for(func, 0, x.lengthOf());
-    // return result;
 }
 void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {

View File

@@ -89,20 +89,6 @@ __device__ T continuedFractionCuda(const T a, const T b, const T x) {
     return 1.f / 0.f;  // no convergence, more iterations is required
 }
-///////////////////////////////////////////////////////////////////
-// evaluates incomplete beta function for positive a and b, and x between 0 and 1.
-template <typename T>
-__device__ T betaIncCoreCuda(T a, T b, T x) {
-    const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
-    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1 - x) * b - gammaPart) / a;
-    if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
-        return front * continuedFractionCuda(a, b, x);
-    else // symmetry relation
-        return static_cast<T>(1) - front * continuedFractionCuda(b, a, static_cast<T>(1) - x);
-}
 ///////////////////////////////////////////////////////////////////
 template<typename T>
 __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
@@ -115,13 +101,22 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
     const Nd4jLong j = blockIdx.x;  // one block per each element
-    Nd4jLong len = shape::length(xShapeInfo);
-    const T a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo));
-    const T b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
-    const T x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
     T& z = *(reinterpret_cast<T*>(vz) + shape::getIndexOffset(j, zShapeInfo));
+    __shared__ T a, b, x;
+    __shared__ bool symmCond;
+    if (threadIdx.x == 0) {
+        a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo));
+        b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
+        x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
+        symmCond = x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2));
+    }
+    __syncthreads();
     // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5
     if(a == b && x == static_cast<T>(0.5)) {
         z = static_cast<T>(0.5);
@@ -135,17 +130,31 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
     if(threadIdx.x % 2 == 0) {      /***** even part *****/
         const int m = threadIdx.x + 1;
+        if(symmCond)
             sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast<T>(1)) * (a + 2 * m));
+        else
+            sharedMem[threadIdx.x] = m * (a - m) * (1.f-x) / ((b + 2 * m - static_cast<T>(1)) * (b + 2 * m));
     }
     else {                          /***** odd part *****/
         const int m = threadIdx.x;
+        if(symmCond)
             sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast<T>(1)) * (a + 2 * m));
+        else
+            sharedMem[threadIdx.x] = -(b + m) * (a + b + m) * (1.f-x) / ((b + 2 * m + static_cast<T>(1)) * (b + 2 * m));
     }
     __syncthreads();
-    if(threadIdx.x == 0)
-        z = betaIncCoreCuda(a, b, x);
+    if(threadIdx.x == 0) {
+        const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
+        const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
+        if (symmCond)
+            z = front * continuedFractionCuda(a, b, x) / a;
+        else // symmetry relation
+            z = static_cast<T>(1) - front * continuedFractionCuda(b, a, static_cast<T>(1) - x) / b;
+    }
 }
 ///////////////////////////////////////////////////////////////////

View File

@@ -0,0 +1,78 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include<ops/declarable/helpers/gammaMathFunc.h>
#include <NDArrayFactory.h>
namespace nd4j {
namespace ops {
namespace helpers {
///////////////////////////////////////////////////////////////////
template<typename T>
__global__ static void diGammaCuda(const void *vx, const Nd4jLong *xShapeInfo,
void *vz, const Nd4jLong *zShapeInfo) {
const auto x = reinterpret_cast<const T*>(vx);
auto z = reinterpret_cast<T*>(vz);
__shared__ Nd4jLong len;
__shared__ bool sameOffset;
if (threadIdx.x == 0) {
len = shape::length(xShapeInfo);
sameOffset = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
}
__syncthreads();
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += gridDim.x * blockDim.x) {
const auto xOffset = shape::getIndexOffset(i, xShapeInfo);
const auto zOffset = sameOffset ? xOffset : shape::getIndexOffset(i, zShapeInfo);
z[zOffset] = diGammaScalar<T>(x[xOffset]);
}
}
///////////////////////////////////////////////////////////////////
template<typename T>
static void diGammaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
diGammaCuda<T><<<blocksPerGrid, threadsPerBlock, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo);
}
///////////////////////////////////////////////////////////////////
void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) {
int threadsPerBlock = MAX_NUM_THREADS / 2;
int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
NDArray::prepareSpecialUse({&z}, {&x});
BUILD_SINGLE_SELECTOR(x.dataType(), diGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);
NDArray::registerSpecialUse({&z}, {&x});
}
BUILD_SINGLE_TEMPLATE(template void diGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES);
}
}
}

View File

@@ -18,7 +18,7 @@
 // @author Yurii Shyrma (iuriish@yahoo.com), created on 26.04.2019
 //
-#include<ops/declarable/helpers/polyGamma.h>
+#include<ops/declarable/helpers/gammaMathFunc.h>
 #include<ops/declarable/helpers/zeta.h>
 #include <NDArrayFactory.h>
@@ -37,9 +37,13 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo,
     auto z = reinterpret_cast<T*>(vz);
     __shared__ Nd4jLong len;
+    __shared__ bool sameOffsetNX, sameOffsetNZ;
-    if (threadIdx.x == 0)
+    if (threadIdx.x == 0) {
         len = shape::length(nShapeInfo);
+        sameOffsetNX = shape::haveSameShapeAndStrides(xShapeInfo, nShapeInfo);
+        sameOffsetNZ = shape::haveSameShapeAndStrides(zShapeInfo, nShapeInfo);
+    }
     __syncthreads();
     const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -48,19 +52,26 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo,
     for (int i = tid; i < len; i += totalThreads) {
         const auto nOffset = shape::getIndexOffset(i, nShapeInfo);
-        const auto xOffset = shape::getIndexOffset(i, xShapeInfo);
-        const auto zOffset = shape::getIndexOffset(i, zShapeInfo);
-        const T nVal = n[nOffset];
-        int sign = (static_cast<int>(nVal) + 1) % 2 ? -1 : 1;
+        const auto xOffset = sameOffsetNX ? nOffset : shape::getIndexOffset(i, xShapeInfo);
+        const auto zOffset = sameOffsetNZ ? nOffset : shape::getIndexOffset(i, zShapeInfo);
+        const T order = n[nOffset];
+        int sign = (static_cast<int>(order) + 1) % 2 ? -1 : 1;
+        if(order != static_cast<int>(order)) {
+            z[zOffset] = DataTypeUtils::nanOrZero<T>();
+        }
+        else if(order == 0) {
+            z[zOffset] = diGammaScalar<T>(x[xOffset]);
+        }
+        else {
             T factorial = 1;
-            if(nVal != 0 && nVal != 1)
-                for(int i = 2; i <= nVal; ++i)
+            for(int i = 2; i <= order; ++i)
                 factorial *= i;
-        z[zOffset] = sign * factorial * zetaScalar<T>(nVal + 1, x[xOffset]);
+            z[zOffset] = sign * factorial * zetaScalar<T>(order + 1, x[xOffset]);
+        }
     }
 }
@@ -76,7 +87,7 @@ void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x
     NDArray::prepareSpecialUse({&z}, {&n, &x});
-    int threadsPerBlock = MAX_NUM_THREADS;
+    int threadsPerBlock = MAX_NUM_THREADS / 2;
     int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
     BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.getSpecialBuffer(), n.getSpecialShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);

View File

@@ -0,0 +1,100 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#ifndef LIBND4J_GAMMAMATHFUNC_H
#define LIBND4J_GAMMAMATHFUNC_H
#include <ops/declarable/helpers/helpers.h>
#include "NDArray.h"
namespace nd4j {
namespace ops {
namespace helpers {
// calculate the digamma function for each element for array
void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z);
// calculate the polygamma function
void polyGamma(nd4j::LaunchContext* context, const NDArray& n, const NDArray& x, NDArray& z);
// calculate the digamma function for one element
// implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x)
template <typename T>
_CUDA_HD T diGammaScalar(T x) {
const int xInt = static_cast<int>(x);
// negative and zero
if(x <= 0) {
if(x == xInt) // integer
return DataTypeUtils::infOrMax<T>();
else
return diGammaScalar<T>(1 - x) - M_PI / nd4j::math::nd4j_tan<T,T>(M_PI * x); // use reflection formula psi(1-x) = psi(x) + pi*cot(pi*x)
}
// positive integer
if(x == xInt && xInt <= 20) { // psi(n) = -Euler_Mascheroni_const + sum_from_k=1_to_n-1( 1/k ), for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n
T result = -0.577215664901532;
for (uint i = 1; i <= xInt - 1; ++i) {
result += static_cast<T>(1) / i;
}
return result;
}
// positive half-integer
if(x - xInt == 0.5 && xInt <= 20) { // psi(n+0.5) = -Euler_Mascheroni_const - 2*ln(2) + sum_from_k=1_to_n( 2/(2*k-1) ) , for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n
T result = -0.577215664901532 - 2 * nd4j::math::nd4j_log<T,T>(2);
for (uint i = 1; i <= xInt; ++i) {
result += static_cast<T>(2) / (2*i - 1);
}
return result;
}
// positive, smaller then 5; we should use number > 5 in order to have satisfactory accuracy in asymptotic expansion
if(x < 5)
return diGammaScalar<T>(1 + x) - static_cast<T>(1) / x; // recurrence formula psi(x) = psi(x+1) - 1/x.
// *** other positive **** //
// truncated expansion formula (from wiki)
// psi(x) = log(x) - 1/(2*x) - 1/(12*x^2) + 1/(120*x^4) - 1/(252*x^6) + 1/(240*x^8) - 5/(660*x^10) + 691/(32760*x^12) - 1/(12*x^14) + ...
if(x >= (sizeof(T) > 4 ? 1.e16 : 1.e8)) // if x is too big take into account only log(x)
return nd4j::math::nd4j_log<T,T>(x);
// coefficients used in truncated asymptotic expansion formula
const T coeffs[7] = {-(T)1/12, (T)1/120, -(T)1/252, (T)1/240, -(T)5/660, (T)691/32760, -(T)1/12};
// const T coeffs[7] = {-0.0833333333333333, 0.00833333333333333, -0.00396825396825397, 0.00416666666666667, -0.00757575757575758, 0.0210927960927961, -0.0833333333333333};
const T x2Inv = static_cast<T>(1) / (x * x);
T result = 0;
for (int i = 6; i >= 0; --i)
result = (result + coeffs[i]) * x2Inv;
return result + nd4j::math::nd4j_log<T,T>(x) - static_cast<T>(0.5) / x;
}
}
}
}
#endif //LIBND4J_GAMMAMATHFUNC_H
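A standalone sanity-check sketch (not the library helper) of the same strategy diGammaScalar uses for positive arguments: shift x up with the recurrence psi(x) = psi(x + 1) - 1/x, then apply the truncated asymptotic expansion psi(x) ~ ln(x) - 1/(2x) - 1/(12x^2) + 1/(120x^4) - 1/(252x^6):

#include <cmath>
#include <cstdio>

// assumes x > 0; negative and zero arguments need the reflection formula handled in the header above
static double digammaSketch(double x) {
    double result = 0.0;
    while (x < 6.0) {            // recurrence: psi(x) = psi(x + 1) - 1/x
        result -= 1.0 / x;
        x += 1.0;
    }
    const double x2Inv = 1.0 / (x * x);
    // truncated asymptotic expansion for large x
    result += std::log(x) - 0.5 / x
            - x2Inv * (1.0 / 12.0 - x2Inv * (1.0 / 120.0 - x2Inv / 252.0));
    return result;
}

int main() {
    // psi(1) = -EulerGamma ~ -0.5772157, psi(0.5) = -EulerGamma - 2*ln(2) ~ -1.9635100
    std::printf("%.6f %.6f\n", digammaSketch(1.0), digammaSketch(0.5));
    return 0;
}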

View File

@@ -36,7 +36,7 @@ namespace platforms {
 //////////////////////////////////////////////////////////////////////
 static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
                           const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
-                          const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
+                          const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode,
                           const int isNCHW) {
     int bS, iC, iH, iW, oC, oH, oW;  // batch size, input channels, input height/width, output channels, output height/width;
@@ -44,8 +44,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
                                                indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
-    if(isSameMode)  // SAME
-        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
     dnnl_memory_desc_t empty;
     dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
@@ -53,7 +52,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
     dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
             empty);
     dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
-    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
+    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
                                            bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
                                            bias, output,
                                            &conv_src_md, nullptr, &conv_weights_md, nullptr,
@@ -115,13 +114,11 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
 //////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(conv2d) {
-    auto input = INPUT_VARIABLE(
-            0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+    auto input = INPUT_VARIABLE(0);    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
     auto weights = INPUT_VARIABLE(1);  // [kH, kW, iC, oC] always
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;  // [oC]
-    auto output = OUTPUT_VARIABLE(
-            0);  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
+    auto output = OUTPUT_VARIABLE(0);  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
     int sH = INT_ARG(2);  // strides height
     int sW = INT_ARG(3);  // strides width
@@ -129,13 +126,13 @@ PLATFORM_IMPL(conv2d) {
     int pW = INT_ARG(5);  // paddings width
     int dH = INT_ARG(6);  // dilations height
     int dW = INT_ARG(7);  // dilations width
-    int isSameMode = INT_ARG(8);   // 0-VALID, 1-SAME
+    int paddingMode = INT_ARG(8);  // 0-VALID, 1-SAME
     bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;  // INT_ARG(9): 0-NCHW, 1-NHWC
     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0));  // filter(kernel) height
     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1));  // filter(kernel) width
-    conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW);
+    conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
     return Status::OK();
 }
@@ -155,18 +152,13 @@ PLATFORM_CHECK(conv2d) {
 //////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(conv2d_bp) {
-    auto input = INPUT_VARIABLE(
-            0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weights = INPUT_VARIABLE(
-            1);  // [kH, kW, iC, oC] always
+    auto input = INPUT_VARIABLE(0);    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+    auto weights = INPUT_VARIABLE(1);  // [kH, kW, iC, oC] always
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;  // [oC]
-    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
-            2);  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
-    auto gradI = OUTPUT_VARIABLE(
-            0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-    auto gradW = OUTPUT_VARIABLE(
-            1);  // [kH, kW, iC, oC] always
+    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
+    auto gradI = OUTPUT_VARIABLE(0);   // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+    auto gradW = OUTPUT_VARIABLE(1);   // [kH, kW, iC, oC] always
     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;  // [oC]
     int kH = INT_ARG(0);  // filter(kernel) height
@@ -177,7 +169,7 @@ PLATFORM_IMPL(conv2d_bp) {
     int pW = INT_ARG(5);  // paddings width
     int dH = INT_ARG(6);  // dilations height
     int dW = INT_ARG(7);  // dilations width
-    int isSameMode = INT_ARG(8);   // 0-VALID, 1-SAME
+    int paddingMode = INT_ARG(8);  // 0-VALID, 1-SAME
     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;  // INT_ARG(9): 0-NCHW, 1-NHWC
     REQUIRE_TRUE(input->rankOf() == 4, 0,
@@ -195,8 +187,7 @@ PLATFORM_IMPL(conv2d_bp) {
     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
                                                indIiH, indWiC, indWoC, indWkH, indOoH);
-    if (isSameMode)  // SAME
-        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
     dnnl_memory_desc_t empty;
     dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
@@ -204,7 +195,7 @@ PLATFORM_IMPL(conv2d_bp) {
     dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
             user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
     dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
-    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
+    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
                                            bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
                                            gradB, gradO,
                                            &conv_src_md, &conv_diff_src_md, &conv_weights_md,
@@ -342,18 +333,13 @@ PLATFORM_CHECK(conv2d_bp) {
     if (::optimalLevel() < 2)
         return false;
-    auto input = INPUT_VARIABLE(
-            0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weights = INPUT_VARIABLE(
-            1);  // [kH, kW, iC, oC] always
+    auto input = INPUT_VARIABLE(0);    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+    auto weights = INPUT_VARIABLE(1);  // [kH, kW, iC, oC] always
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr;  // [oC]
-    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
-            2);  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
-    auto gradI = OUTPUT_VARIABLE(
-            0);  // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-    auto gradW = OUTPUT_VARIABLE(
-            1);  // [kH, kW, iC, oC] always
+    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2);  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
+    auto gradI = OUTPUT_VARIABLE(0);   // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+    auto gradW = OUTPUT_VARIABLE(1);   // [kH, kW, iC, oC] always
     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;  // [oC]

View File

@@ -1568,11 +1568,39 @@ namespace simdOps {
             return opOutput + old;
         }
         op_def static Z update(X old, X opOutput, X *extraParams) {
             return opOutput + old;
         }
+        op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) {
+            return reduction;
+        }
+    };
+
+    template <typename X, typename Z>
+    class IsNegative {
+    public:
+        no_op_exec_special_bool
+        no_op_exec_special_bool_cuda
+
+        no_op_exec_special_accumulation
+        no_op_exec_special_accumulation_cuda
+
+        op_def static Z op(X d1, X *params) {
+            return d1 < (X)0.f;
+        }
+
+        op_def static X startingValue(const X *input) {
+            return static_cast<X>(0);
+        }
+
+        op_def static Z merge(X old, X opOutput, X *extraParams) {
+            return opOutput + old;
+        }
+
+        op_def static Z update(X old, X opOutput, X *extraParams) {
+            return opOutput + old;
+        }
         op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) {
             return reduction;
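Host-side sketch of what the new reduce-bool op computes, as inferred from its use in polygamma.cpp (a non-zero reduction means at least one element is negative):

#include <algorithm>
#include <vector>

// per-element op: d1 < 0; merge/update: sum of flags; postProcess: identity
bool anyNegative(const std::vector<float>& values) {
    return std::any_of(values.begin(), values.end(), [](float v) { return v < 0.f; });
}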

View File

@@ -1008,6 +1008,38 @@ TEST_F(ConvolutionTests1, conv1d_causal_6) {
     delete results;
 }
+//////////////////////////////////////////////////////////////////////
+TEST_F(ConvolutionTests1, conv1d_causal_7) {
+    int bS=2, iW=8, iC=3, oC=4, kW=2, sW=1, pW=0, dW=1;
+    int oW = (iW-1)/sW + 1;
+    int paddingMode = 2;  // CAUSAL
+    int dataFormat  = 1;  // 1-NHWC, 0-NCHW
+    NDArray input('c', {bS, iW, iC}, nd4j::DataType::FLOAT32);
+    NDArray weights('c', {kW, iC, oC}, nd4j::DataType::FLOAT32);
+    NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, 34.299999, 36.400002, 49.899998, 53.800003, 57.699997,
+        61.599998, 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, 129.100006,
+        140.199997, 151.300003, 162.399994, 148.899994, 161.800003, 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, 188.500000, 205.000000,
+        221.500000, 238.000000, 208.299988, 226.600006, 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, 247.899994, 269.799988, 291.700012,
+        313.600006, 267.700012, 291.399994, 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, 307.299988, 334.600006, 361.899994, 389.200012}, nd4j::DataType::FLOAT32);
+    input.linspace(1., 1.);
+    weights.linspace(0.1, 0.1);
+    nd4j::ops::conv1d op;
+    auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+    auto output = results->at(0);
+    ASSERT_EQ(Status::OK(), results->status());
+    ASSERT_TRUE(expOutput.isSameShape(output));
+    ASSERT_TRUE(expOutput.equalsTo(output));
+    delete results;
+}
 //////////////////////////////////////////////////////////////////////
 TEST_F(ConvolutionTests1, conv1d_causal_bp_1) {
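A small sketch of the causal padding arithmetic the new test relies on (assumption: the usual causal-convolution formula, with all padding placed before the first time step), using the parameters of conv1d_causal_7 above:

#include <cstdio>

int main() {
    const int iW = 8, kW = 2, sW = 1, dW = 1;      // matches conv1d_causal_7
    const int oW = (iW + sW - 1) / sW;             // ceil(iW / sW), output width for causal mode
    const int pLeft = (kW - 1) * dW;               // all padding goes before the first time step
    std::printf("oW = %d, left padding = %d\n", oW, pLeft);  // oW = 8, left padding = 1
    return 0;
}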

View File

@@ -428,10 +428,11 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) {
 TEST_F(DeclarableOpsTests13, adjustHue_1) {
     NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
+    NDArray factor = NDArrayFactory::create<float>(0.5);
     NDArray exp  ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32);
     nd4j::ops::adjust_hue op;
-    auto results = op.execute({&input}, {0.5}, {2});
+    auto results = op.execute({&input, &factor}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());
@@ -525,10 +526,11 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) {
 TEST_F(DeclarableOpsTests13, adjustSaturation_1) {
     NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
+    NDArray factor = NDArrayFactory::create<float>(0.5);
     NDArray exp  ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32);
     nd4j::ops::adjust_saturation op;
-    auto results = op.execute({&input}, {0.5}, {2});
+    auto results = op.execute({&input, &factor}, {}, {2});
     ASSERT_EQ(ND4J_STATUS_OK, results->status());

View File

@@ -159,18 +159,19 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {
 TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
     auto x = NDArrayFactory::create<double>('c', {4,4,3});
-    auto e = NDArrayFactory::create<double>('c', {4,4,3}, {
-        -21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
+    NDArray factor = NDArrayFactory::create<double>(2.);
+    auto e = NDArrayFactory::create<double>('c', {4,4,3}, {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
         2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
         26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
-        50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
-    });
+        50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5});
     x.linspace(1.);
     nd4j::ops::adjust_contrast op;
-    auto result = op.execute({&x}, {2.}, {}, {});
+    auto result = op.execute({&x, &factor}, {}, {}, {});
     ASSERT_EQ(Status::OK(), result->status());
     auto out = result->at(0);
-    // out->printIndexedBuffer("Adjusted Constrast");
     ASSERT_TRUE(e.equalsTo(out));
     delete result;
 }

View File

@@ -1774,6 +1774,28 @@ TEST_F(DeclarableOpsTests3, betainc_test10) {
     delete results;
 }
+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests3, betainc_test11) {
+    NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, nd4j::DataType::FLOAT32);
+    NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32);
+    NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32);
+    NDArray expected('c', {4}, {0.912156, 0.634443, 0.898314, 0.624544}, nd4j::DataType::FLOAT32);
+    nd4j::ops::betainc op;
+    auto results = op.execute({&a, &b, &x}, {}, {});
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+    auto *output = results->at(0);
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+    delete results;
+}
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests3, zeta_test1) {
@@ -2092,8 +2114,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
     x.linspace(10.);
     auto expected= NDArrayFactory::create<double>('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07});
-    //ASSERT_FALSE(true);
+    nd4j::ops::polygamma op;
+    auto results = op.execute({&n, &x}, {}, {});
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+    auto output = results->at(0);
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+    delete results;
+}
+
+TEST_F(DeclarableOpsTests3, polygamma_test4) {
+    NDArray n('c', {3,4}, {/*0.7788*/0, 0,1,2,3,4,5,6,7,8,9,10}, nd4j::DataType::DOUBLE);
+    NDArray x('c', {3,4}, {0.7717,0.9281,0.9846,0.4838,0.6433,0.6041,0.6501,0.7612,0.7605,0.3948,0.9493,0.8600}, nd4j::DataType::DOUBLE);
+    NDArray expected('c', {3,4}, {/*std::numeric_limits<double>::quiet_NaN()*/-1.031918, -7.021327e-01, 1.682743e+00, -1.851378e+01,3.604167e+01, -3.008293e+02,
+        1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE);
     nd4j::ops::polygamma op;
     auto results = op.execute({&n, &x}, {}, {});
@@ -2108,6 +2148,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
     delete results;
 }
+
+TEST_F(DeclarableOpsTests3, digamma_1) {
+    NDArray x('c', {18}, {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, nd4j::DataType::DOUBLE);
+    NDArray expected('c', {18}, {std::numeric_limits<double>::infinity(), -99996.761229, 3.091129, 7.401432, 1.792911,11.196838,10.630354, 0.03649, 2.11331,
+        std::numeric_limits<double>::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE);
+    nd4j::ops::digamma op;
+    auto results = op.execute({&x}, {}, {});
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+    auto output = results->at(0);
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+    delete results;
+}
 ///////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests3, svd_test1) {