Shyrma adjust (#98)

* Add the possibility of passing a scalar array as the input parameter for the scale factor in the adjust_hue / adjust_contrast / adjust_saturation ops; correct a typo in the function that calculates the regularized incomplete beta integral
* Fix a bug in the betainc cuda kernel
* Start working on the implementation of the digamma function
* Further work on the digamma function (cpu)
* Test and fix bugs in the digamma op
* Correct the cuda kernel for polyGamma
* Remove unnecessary stuff from the betaInc cuda kernel
* Resolve conflicts in DeclarableOpsTests3.cpp after the master branch has been merged
* Restore the id number of the Not operation in legacy_ops.h
* Correct the padding calculation in the mkl-dnn causal conv1d
* Restore the empty check in adjust_contrast_v2

Signed-off-by: Yurii <iuriish@yahoo.com>
Signed-off-by: raver119 <raver119@gmail.com>

parent 1e9ff114aa
commit 1f5e15b541
@@ -112,7 +112,8 @@
         (4, IsInfOrNan), \
         (5, MatchConditionBool), \
         (6, IsPositive) , \
-        (7, Not)
+        (7, Not), \
+        (8, IsNegative)
 
 
 #define TRANSFORM_STRICT_OPS \
@@ -279,7 +280,8 @@
         (3, IsInfOrNan), \
         (4, IsNan), \
         (5, IsInf), \
-        (6, IsPositive)
+        (6, IsPositive), \
+        (7, IsNegative)
 
 
 #define REDUCE_SAME_OPS \
         (0, Sum), \
@@ -27,7 +27,8 @@
 namespace nd4j {
 namespace ops {
 
-CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) {
+////////////////////////////////////////////////////////////////////
+CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, 0, 0) {
 
     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
@@ -37,23 +38,31 @@ CONFIGURABLE_OP_IMPL(adjust_contrast, 1, 1, true, -2, 0) {
         return Status::OK();
 
     REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST: Scale factor required");
 
-    const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0);
-
     REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
     REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
-    // compute mean before
+
+    NDArray* factor = nullptr;
+
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
 
     // fill up axes vector first
     std::vector<int> axes(input->rankOf() - 1);
     for (auto i = 0; i < axes.size(); ++i)
         axes[i] = i;
 
     // mean as reduction for last dimension set
     auto mean = input->reduceAlongDims(reduce::Mean, axes);
 
-    NDArray factorT(output->dataType(), block.launchContext()); // = NDArrayFactory::create(factor, block.launchContext());
-    factorT.p(0, factor);
     // this is contrast calculation
-    output->assign((*input - mean) * factorT + mean);
+    output->assign((*input - mean) * (*factor) + mean);
+
+    if(block.width() == 1)
+        delete factor;
 
     return Status::OK();
 }
@@ -64,22 +73,28 @@ DECLARE_TYPES(adjust_contrast) {
             ->setSameMode(true);
 }
 
-CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, -2, 0) {
+////////////////////////////////////////////////////////////////////
+CONFIGURABLE_OP_IMPL(adjust_contrast_v2, 1, 1, true, 0, 0) {
 
     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
 
-    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required");
-
-    const double factor = block.width() > 1 ? INPUT_VARIABLE(1)->e<double>(0) : T_ARG(0);
-
     // just skip op if input is empty
     if (input->isEmpty())
         return Status::OK();
 
     REQUIRE_TRUE(input->rankOf() > 2, 0, "ADJUST_CONTRAST_V2: op expects rank of input array to be >= 3, but got %i instead", input->rankOf());
     REQUIRE_TRUE(input->sizeAt(-1) == 3, 0, "ADJUST_CONTRAST_V2: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(-1));
+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_CONTRAST_V2: Scale factor required");
+
+    NDArray* factor = nullptr;
+
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
 
     // compute mean before
     std::vector<int> axes(input->rankOf() - 1);
@@ -92,9 +107,12 @@ DECLARE_TYPES(adjust_contrast) {
     // result as (x - mean) * factor + mean
     auto temp = input->ulike();
     input->applyTrueBroadcast(BroadcastOpsTuple::Subtract(), &mean, &temp);
-    temp.applyScalar(scalar::Multiply, factor);
+    temp.applyScalarArr(scalar::Multiply, factor);
     temp.applyTrueBroadcast(BroadcastOpsTuple::Add(), &mean, output);
 
+    if(block.width() == 1)
+        delete factor;
+
     return Status::OK();
 }
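Note: with this change the scale factor for adjust_contrast (and likewise the delta for adjust_hue and the factor for adjust_saturation) can be supplied either as a T argument or as a second scalar-array input. A minimal usage sketch in the test style used later in this PR (shapes and values are illustrative, not taken from the commit):

    nd4j::ops::adjust_contrast op;

    NDArray input('c', {4, 4, 3}, nd4j::DataType::FLOAT32);
    NDArray factor = NDArrayFactory::create<double>(2.);   // scalar-array holding the contrast factor

    // variant 1: factor passed as the second input array
    auto res1 = op.execute({&input, &factor}, {}, {});

    // variant 2: factor passed as a T argument (previous behaviour, still supported)
    auto res2 = op.execute({&input}, {2.}, {});

    delete res1;
    delete res2;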
@@ -24,13 +24,12 @@
 
 #include <ops/declarable/headers/parity_ops.h>
 #include <ops/declarable/helpers/adjust_hue.h>
-#include <NDArrayFactory.h>
 
 namespace nd4j {
 namespace ops {
 
 
-CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {
+CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 0, 0) {
 
     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
@@ -41,15 +40,26 @@ CONFIGURABLE_OP_IMPL(adjust_hue, 1, 1, true, 1, -2) {
 
     const int rank = input->rankOf();
     const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
-    const double delta = T_ARG(0);
 
+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_HUE: delta factor is required !");
     REQUIRE_TRUE(rank >= 3, 0, "ADJUST_HUE: op expects rank of input array to be >= 3, but got %i instead", rank);
     REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_HUE: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
-    REQUIRE_TRUE(-1. <= delta && delta <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);
 
-    NDArray deltaScalarArr = NDArrayFactory::create<double>(delta, block.launchContext());
+    NDArray* delta = nullptr;
 
-    helpers::adjustHue(block.launchContext(), input, &deltaScalarArr, output, dimC);
+    if(block.width() > 1)
+        delta = INPUT_VARIABLE(1);
+    else {
+        delta = new NDArray(output->dataType(), block.launchContext());
+        delta->p(0, T_ARG(0));
+    }
+
+    REQUIRE_TRUE(-1. <= delta->e<double>(0) && delta->e<double>(0) <= 1., 0, "ADJUST_HUE: parameter delta must be within [-1, 1] interval, but got %f instead", delta);
+
+    helpers::adjustHue(block.launchContext(), input, delta, output, dimC);
+
+    if(block.width() == 1)
+        delete delta;
 
     return Status::OK();
 }
@@ -28,7 +28,7 @@
 namespace nd4j {
 namespace ops {
 
-CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
+CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 0, 0) {
 
     auto input  = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);
@@ -39,14 +39,24 @@ CONFIGURABLE_OP_IMPL(adjust_saturation, 1, 1, true, 1, -2) {
 
     const int rank = input->rankOf();
     const int dimC = block.getIArguments()->size() > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
-    const double factor = T_ARG(0);
 
     REQUIRE_TRUE(rank >= 3, 0, "ADJUST_SATURATION: op expects rank of input array to be >= 3, but got %i instead", rank);
     REQUIRE_TRUE(input->sizeAt(dimC) == 3, 0, "ADJUST_SATURATION: operation expects image with 3 channels (R, G, B), but got %i instead", input->sizeAt(dimC));
+    REQUIRE_TRUE(block.numT() > 0 || block.width() > 1, 0, "ADJUST_SATURATION: scale factor is required !");
 
-    NDArray factorScalarArr = NDArrayFactory::create<double>(factor, block.launchContext());
+    NDArray* factor = nullptr;
 
-    helpers::adjustSaturation(block.launchContext(), input, &factorScalarArr, output, dimC);
+    if(block.width() > 1)
+        factor = INPUT_VARIABLE(1);
+    else {
+        factor = new NDArray(output->dataType(), block.launchContext());
+        factor->p(0, T_ARG(0));
+    }
+
+    helpers::adjustSaturation(block.launchContext(), input, factor, output, dimC);
+
+    if(block.width() == 1)
+        delete factor;
 
     return Status::OK();
 }
@@ -1,5 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
  *
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
@@ -15,27 +16,35 @@
  ******************************************************************************/
 
 //
-// Created by Yurii Shyrma on 13.12.2017.
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 
-#ifndef LIBND4J_POLYGAMMA_H
-#define LIBND4J_POLYGAMMA_H
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_digamma)
 
-#include <ops/declarable/helpers/helpers.h>
-#include "NDArray.h"
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/gammaMathFunc.h>
 
 namespace nd4j {
 namespace ops {
-namespace helpers {
 
+CONFIGURABLE_OP_IMPL(digamma, 1, 1, false, 0, 0) {
 
-// calculate the polygamma function
-void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output);
+    auto x = INPUT_VARIABLE(0);
+    auto z = OUTPUT_VARIABLE(0);
 
+    helpers::diGamma(block.launchContext(), *x, *z);
+
+    return Status::OK();
+}
+
+DECLARE_TYPES(digamma) {
+    getOpDescriptor()
+        ->setAllowedInputTypes({ALL_FLOATS, ALL_INTS})
+        ->setSameMode(true);
+}
 
 }
 }
-}
 
-#endif //LIBND4J_POLYGAMMA_H
+#endif
@@ -15,14 +15,14 @@
  ******************************************************************************/
 
 //
-// @author Yurii Shyrma (iuriish@yahoo.com), created on 13.12.2017
+// @author Yurii Shyrma (iuriish@yahoo.com)
 //
 
 #include <op_boilerplate.h>
 #if NOT_EXCLUDED(OP_polygamma)
 
 #include <ops/declarable/CustomOperations.h>
-#include <ops/declarable/helpers/polyGamma.h>
+#include <ops/declarable/helpers/gammaMathFunc.h>
 
 namespace nd4j {
 namespace ops {
@@ -37,11 +37,11 @@ CONFIGURABLE_OP_IMPL(polygamma, 2, 1, false, 0, 0) {
 
     Nd4jLong arrLen = n->lengthOf();
     // FIXME: this shit should be single op call, not a loop!
-    auto nPositive = n->reduceNumber(nd4j::reduce::IsPositive, nullptr);
+    auto nNegative = n->reduceNumber(nd4j::reduce::IsNegative, nullptr);
     auto xPositive = x->reduceNumber(nd4j::reduce::IsPositive, nullptr);
-    bool nPositiveFlag = nPositive.e<bool>(0);
-    bool xPositiveFlag = xPositive.e<bool>(0);
-    REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be > 0 !");
+    bool nPositiveFlag = !nNegative.e<bool>(0); // require all n >= 0
+    bool xPositiveFlag = xPositive.e<bool>(0);  // require all x > 0
+    REQUIRE_TRUE(nPositiveFlag, 0, "POLYGAMMA op: all elements of n array must be >= 0 !");
     REQUIRE_TRUE(xPositiveFlag, 0, "POLYGAMMA op: all elements of x array must be > 0 !");
 
     helpers::polyGamma(block.launchContext(), *n, *x, *output);
@@ -513,7 +513,6 @@ namespace nd4j {
     /**
      * This op calculates polygamma function psi^(n)(x). Implementation is based on serial representation written in
      * terms of the Hurwitz zeta function: polygamma = (-1)^{n+1} * n! * zeta(n+1, x).
-     * Currently the case n = 0 is not supported.
      *
      * Input arrays:
      *    0: n - define derivative order (n+1), type integer (however currently is implemented as float casted to integer)
@@ -528,6 +527,20 @@ namespace nd4j {
         DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0);
     #endif
 
+    /**
+     * This op calculates digamma function psi(x) = derivative of log(Gamma(x))
+     *
+     * Input arrays:
+     *    0: x - abscissa points where to evaluate the digamma function, type float
+     *
+     * Output array:
+     *    0: values of digamma function at corresponding x, type float
+     *
+     */
+    #if NOT_EXCLUDED(OP_digamma)
+        DECLARE_CONFIGURABLE_OP(digamma, 1, 1, false, 0, 0);
+    #endif
+
     /**
      * This operation takes shape as first argument, and returns new NDArray filled with specific scalar value.
      * Input arrays:
@@ -575,44 +588,47 @@ namespace nd4j {
     /**
      * This operation adjusts image hue by delta
      * Input arrays:
      *    0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
+     *    1 - optional argument, input scalar-array containing delta
      *
      * T arguments:
-     *    0 - delta value
+     *    0 - optional argument, delta value
      *
      * Int arguments:
      *    0 - optional argument, corresponds to dimension with 3 channels
      */
     #if NOT_EXCLUDED(OP_adjust_hue)
-        DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 1, -2);
+        DECLARE_CONFIGURABLE_OP(adjust_hue, 1, 1, true, 0, 0);
     #endif
 
     /**
      * This operation adjusts image saturation by delta
      * Input arrays:
      *    0 - input array with rank >= 3, must have at least one dimension equal 3, that is dimension containing channels.
+     *    1 - optional argument, input scalar-array containing saturation factor
      *
      * T arguments:
-     *    0 - saturation factor
+     *    0 - optional argument, saturation factor
      *
      * Int arguments:
      *    0 - optional argument, corresponds to dimension with 3 channels
      */
     #if NOT_EXCLUDED(OP_adjust_saturation)
-        DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 1, -2);
+        DECLARE_CONFIGURABLE_OP(adjust_saturation, 1, 1, true, 0, 0);
     #endif
 
     /**
      * This operation adjusts image contrast by given factor ( z = (x - mean) * factor + mean )
      * Input arrays:
      *    0 - input array with rank >= 3, must have last one dimension equal 3, that is dimension containing channels.
+     *    1 - optional argument, input scalar-array containing saturation contrast factor
      *
      * T arguments:
-     *    0 - contrast factor
+     *    0 - optional argument, contrast factor
      *
      */
     #if NOT_EXCLUDED(OP_adjust_contrast)
-        DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, -2, 0);
-        DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, -2, 0);
+        DECLARE_CONFIGURABLE_OP(adjust_contrast, 1, 1, true, 0, 0);
+        DECLARE_CONFIGURABLE_OP(adjust_contrast_v2, 1, 1, true, 0, 0);
     #endif
@@ -107,12 +107,12 @@ static T betaIncCore(T a, T b, T x) {
         return x;
 
     const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
-    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1 - x) * b - gammaPart) / a;
+    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
 
     if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
-        return front * continuedFraction(a, b, x);
+        return front * continuedFraction(a, b, x) / a;
     else // symmetry relation
-        return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x);
+        return static_cast<T>(1) - front * continuedFraction(b, a, static_cast<T>(1) - x) / b;
 
 }
@@ -0,0 +1,53 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include<ops/declarable/helpers/gammaMathFunc.h>
+#include <execution/Threads.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+// calculate digamma function for array elements
+template <typename T>
+static void diGamma_(const NDArray& x, NDArray& z) {
+
+    auto func = PRAGMA_THREADS_FOR {
+        for (auto i = start; i < stop; i += increment)
+            z.p(i, diGammaScalar<T>(x.e<T>(i)));
+    };
+    samediff::Threads::parallel_for(func, 0, x.lengthOf());
+}
+
+void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) {
+    BUILD_SINGLE_SELECTOR(x.dataType(), diGamma_, (x, z), FLOAT_TYPES);
+}
+
+BUILD_SINGLE_TEMPLATE(template void diGamma_, (const NDArray& x, NDArray& z), FLOAT_TYPES);
+
+}
+}
+}
@@ -18,7 +18,7 @@
 // Created by Yurii Shyrma on 12.12.2017
 //
 
-#include<ops/declarable/helpers/polyGamma.h>
+#include<ops/declarable/helpers/gammaMathFunc.h>
 #include<ops/declarable/helpers/zeta.h>
 #include <NDArrayFactory.h>
 #include <execution/Threads.h>
@@ -57,8 +57,6 @@ static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n,
     // if (x <= (T)0.)
     //     throw("polyGamma function: x must be > 0 !");
 
-    // TODO case for n = 0 (digamma)
-
     int sign = (n + 1) % 2 ? -1 : 1;
     // T factorial = (T)std::tgamma(n + 1);
@@ -71,17 +69,18 @@ static FORCEINLINE T polyGammaScalar(nd4j::LaunchContext * context, const int n,
 template <typename T>
 static void polyGamma_(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {
 
-    NDArray& result = output;
-
-    int xLen = x.lengthOf();
-
     auto func = PRAGMA_THREADS_FOR {
-        for (auto i = start; i < stop; i += increment)
-            result.p(i, polyGammaScalar<T>(context, n.e<int>(i), x.e<T>(i)));
+        for (auto i = start; i < stop; i += increment) {
+            const T order = n.e<T>(i);
+            if(order != static_cast<int>(order))      // if order has fractional part then do not perform calculations and return NAN
+                output.p(i, std::numeric_limits<T>::quiet_NaN());
+            else if (order == 0)                      // polygamma function of zero order is digamma function
+                output.p(i, diGammaScalar<T>(x.e<T>(i)));
+            else
+                output.p(i, polyGammaScalar<T>(context, order, x.e<T>(i)));
+        }
     };
     samediff::Threads::parallel_for(func, 0, x.lengthOf());
 
-    // return result;
 }
 
 void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x, NDArray& output) {
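For reference, the identity behind both branches (also stated in parity_ops.h) is

    \psi^{(n)}(x) = (-1)^{n+1}\, n!\, \zeta(n+1, x), \qquad n \ge 1,

and the n = 0 case is the digamma function \psi(x) = \frac{d}{dx}\ln\Gamma(x). That is why a fractional order yields NaN, order 0 dispatches to diGammaScalar, and only integer orders >= 1 go through the zeta-based polyGammaScalar.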
@@ -89,20 +89,6 @@ __device__ T continuedFractionCuda(const T a, const T b, const T x) {
     return 1.f / 0.f;   // no convergence, more iterations is required
 }
 
-///////////////////////////////////////////////////////////////////
-// evaluates incomplete beta function for positive a and b, and x between 0 and 1.
-template <typename T>
-__device__ T betaIncCoreCuda(T a, T b, T x) {
-
-    const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
-    const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1 - x) * b - gammaPart) / a;
-
-    if (x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2)))
-        return front * continuedFractionCuda(a, b, x);
-    else // symmetry relation
-        return static_cast<T>(1) - front * continuedFractionCuda(b, a, static_cast<T>(1) - x);
-}
-
 ///////////////////////////////////////////////////////////////////
 template<typename T>
 __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
@@ -115,13 +101,22 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
 
     const Nd4jLong j = blockIdx.x;   // one block per each element
 
-    Nd4jLong len = shape::length(xShapeInfo);
-
-    const T a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo));
-    const T b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
-    const T x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
     T& z = *(reinterpret_cast<T*>(vz) + shape::getIndexOffset(j, zShapeInfo));
 
+    __shared__ T a, b, x;
+    __shared__ bool symmCond;
+
+    if (threadIdx.x == 0) {
+        a = *(reinterpret_cast<const T*>(va) + shape::getIndexOffset(j, aShapeInfo));
+        b = *(reinterpret_cast<const T*>(vb) + shape::getIndexOffset(j, bShapeInfo));
+        x = *(reinterpret_cast<const T*>(vx) + shape::getIndexOffset(j, xShapeInfo));
+
+        symmCond = x <= (a + static_cast<T>(1)) / (a + b + static_cast<T>(2));
+    }
+    __syncthreads();
+
     // t^{n-1} * (1 - t)^{n-1} is symmetric function with respect to x = 0.5
     if(a == b && x == static_cast<T>(0.5)) {
         z = static_cast<T>(0.5);
@@ -135,17 +130,31 @@ __global__ void betaIncForArrayCuda(const void* va, const Nd4jLong* aShapeInfo,
 
     if(threadIdx.x % 2 == 0) { /***** even part *****/
         const int m = threadIdx.x + 1;
-        sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast<T>(1)) * (a + 2 * m));
+        if(symmCond)
+            sharedMem[threadIdx.x] = m * (b - m) * x / ((a + 2 * m - static_cast<T>(1)) * (a + 2 * m));
+        else
+            sharedMem[threadIdx.x] = m * (a - m) * (1.f-x) / ((b + 2 * m - static_cast<T>(1)) * (b + 2 * m));
     }
     else { /***** odd part *****/
         const int m = threadIdx.x;
-        sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast<T>(1)) * (a + 2 * m));
+        if(symmCond)
+            sharedMem[threadIdx.x] = -(a + m) * (a + b + m) * x / ((a + 2 * m + static_cast<T>(1)) * (a + 2 * m));
+        else
+            sharedMem[threadIdx.x] = -(b + m) * (a + b + m) * (1.f-x) / ((b + 2 * m + static_cast<T>(1)) * (b + 2 * m));
     }
 
     __syncthreads();
 
-    if(threadIdx.x == 0)
-        z = betaIncCoreCuda(a, b, x);
+    if(threadIdx.x == 0) {
+
+        const T gammaPart = lgamma(a) + lgamma(b) - lgamma(a + b);
+        const T front = math::nd4j_exp<T,T>(math::nd4j_log<T, T>(x) * a + math::nd4j_log<T, T>(1.f - x) * b - gammaPart);
+
+        if (symmCond)
+            z = front * continuedFractionCuda(a, b, x) / a;
+        else // symmetry relation
+            z = static_cast<T>(1) - front * continuedFractionCuda(b, a, static_cast<T>(1) - x) / b;
+    }
 }
 
 ///////////////////////////////////////////////////////////////////
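The per-thread terms written into sharedMem are the even/odd coefficients of the same continued fraction as on the CPU side, evaluated either for (a, b, x) when symmCond holds or for the swapped arguments (b, a, 1-x) otherwise:

    d_{2m}   = \frac{m\,(b-m)\,x}{(a+2m-1)(a+2m)}, \qquad
    d_{2m+1} = -\,\frac{(a+m)(a+b+m)\,x}{(a+2m)(a+2m+1)}.

Each block handles one output element: the threads fill the coefficients, and thread 0 then evaluates the prefactor and calls continuedFractionCuda, which is why the separate betaIncCoreCuda helper could be removed.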
@@ -0,0 +1,78 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#include<ops/declarable/helpers/gammaMathFunc.h>
+#include <NDArrayFactory.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+__global__ static void diGammaCuda(const void *vx, const Nd4jLong *xShapeInfo,
+                                         void *vz, const Nd4jLong *zShapeInfo) {
+
+    const auto x = reinterpret_cast<const T*>(vx);
+          auto z = reinterpret_cast<T*>(vz);
+
+    __shared__ Nd4jLong len;
+    __shared__ bool sameOffset;
+
+    if (threadIdx.x == 0) {
+        len = shape::length(xShapeInfo);
+        sameOffset = shape::haveSameShapeAndStrides(xShapeInfo, zShapeInfo);
+    }
+    __syncthreads();
+
+    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < len; i += gridDim.x * blockDim.x) {
+
+        const auto xOffset = shape::getIndexOffset(i, xShapeInfo);
+        const auto zOffset = sameOffset ? xOffset : shape::getIndexOffset(i, zShapeInfo);
+
+        z[zOffset] = diGammaScalar<T>(x[xOffset]);
+    }
+}
+
+///////////////////////////////////////////////////////////////////
+template<typename T>
+static void diGammaCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo) {
+
+    diGammaCuda<T><<<blocksPerGrid, threadsPerBlock, 1024, *stream>>>(vx, xShapeInfo, vz, zShapeInfo);
+}
+
+///////////////////////////////////////////////////////////////////
+void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z) {
+
+    int threadsPerBlock = MAX_NUM_THREADS / 2;
+    int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+
+    NDArray::prepareSpecialUse({&z}, {&x});
+    BUILD_SINGLE_SELECTOR(x.dataType(), diGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);
+    NDArray::registerSpecialUse({&z}, {&x});
+}
+
+BUILD_SINGLE_TEMPLATE(template void diGammaCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo), FLOAT_TYPES);
+
+}
+}
+}
@@ -18,7 +18,7 @@
 // @author Yurii Shyrma (iuriish@yahoo.com), created on 26.04.2019
 //
 
-#include<ops/declarable/helpers/polyGamma.h>
+#include<ops/declarable/helpers/gammaMathFunc.h>
 #include<ops/declarable/helpers/zeta.h>
 #include <NDArrayFactory.h>
@@ -37,9 +37,13 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo,
     auto z = reinterpret_cast<T*>(vz);
 
     __shared__ Nd4jLong len;
+    __shared__ bool sameOffsetNX, sameOffsetNZ;
 
-    if (threadIdx.x == 0)
+    if (threadIdx.x == 0) {
         len = shape::length(nShapeInfo);
+        sameOffsetNX = shape::haveSameShapeAndStrides(xShapeInfo, nShapeInfo);
+        sameOffsetNZ = shape::haveSameShapeAndStrides(zShapeInfo, nShapeInfo);
+    }
     __syncthreads();
 
     const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
@@ -48,19 +52,26 @@ __global__ static void polyGammaCuda(const void *vn, const Nd4jLong *nShapeInfo,
     for (int i = tid; i < len; i += totalThreads) {
 
         const auto nOffset = shape::getIndexOffset(i, nShapeInfo);
-        const auto xOffset = shape::getIndexOffset(i, xShapeInfo);
-        const auto zOffset = shape::getIndexOffset(i, zShapeInfo);
+        const auto xOffset = sameOffsetNX ? nOffset : shape::getIndexOffset(i, xShapeInfo);
+        const auto zOffset = sameOffsetNZ ? nOffset : shape::getIndexOffset(i, zShapeInfo);
 
-        const T nVal = n[nOffset];
+        const T order = n[nOffset];
 
-        int sign = (static_cast<int>(nVal) + 1) % 2 ? -1 : 1;
+        int sign = (static_cast<int>(order) + 1) % 2 ? -1 : 1;
 
-        T factorial = 1;
-        if(nVal != 0 && nVal != 1)
-            for(int i = 2; i <= nVal; ++i)
-                factorial *= i;
-
-        z[zOffset] = sign * factorial * zetaScalar<T>(nVal + 1, x[xOffset]);
+        if(order != static_cast<int>(order)) {
+            z[zOffset] = DataTypeUtils::nanOrZero<T>();
+        }
+        else if(order == 0) {
+            z[zOffset] = diGammaScalar<T>(x[xOffset]);
+        }
+        else {
+            T factorial = 1;
+            for(int i = 2; i <= order; ++i)
+                factorial *= i;
+
+            z[zOffset] = sign * factorial * zetaScalar<T>(order + 1, x[xOffset]);
+        }
     }
 }
@@ -76,7 +87,7 @@ void polyGamma(nd4j::LaunchContext * context, const NDArray& n, const NDArray& x
 
     NDArray::prepareSpecialUse({&z}, {&n, &x});
 
-    int threadsPerBlock = MAX_NUM_THREADS;
+    int threadsPerBlock = MAX_NUM_THREADS / 2;
     int blocksPerGrid = (z.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
 
     BUILD_SINGLE_SELECTOR(n.dataType(), polyGammaCudaLauncher, (blocksPerGrid, threadsPerBlock, context->getCudaStream(), n.getSpecialBuffer(), n.getSpecialShapeInfo(), x.getSpecialBuffer(), x.getSpecialShapeInfo(), z.getSpecialBuffer(), z.getSpecialShapeInfo()), FLOAT_TYPES);
@@ -0,0 +1,100 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author Yurii Shyrma (iuriish@yahoo.com)
+//
+
+#ifndef LIBND4J_GAMMAMATHFUNC_H
+#define LIBND4J_GAMMAMATHFUNC_H
+
+#include <ops/declarable/helpers/helpers.h>
+#include "NDArray.h"
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+    // calculate the digamma function for each element for array
+    void diGamma(nd4j::LaunchContext* context, const NDArray& x, NDArray& z);
+
+    // calculate the polygamma function
+    void polyGamma(nd4j::LaunchContext* context, const NDArray& n, const NDArray& x, NDArray& z);
+
+    // calculate the digamma function for one element
+    // implementation is based on serial representation written in terms of the Hurwitz zeta function as polygamma = (-1)^{n+1} * n! * zeta(n+1, x)
+    template <typename T>
+    _CUDA_HD T diGammaScalar(T x) {
+
+        const int xInt = static_cast<int>(x);
+
+        // negative and zero
+        if(x <= 0) {
+            if(x == xInt)   // integer
+                return DataTypeUtils::infOrMax<T>();
+            else
+                return diGammaScalar<T>(1 - x) - M_PI / nd4j::math::nd4j_tan<T,T>(M_PI * x);   // use reflection formula psi(1-x) = psi(x) + pi*cot(pi*x)
+        }
+
+        // positive integer
+        if(x == xInt && xInt <= 20) {   // psi(n) = -Euler_Mascheroni_const + sum_from_k=1_to_n-1( 1/k ), for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n
+            T result = -0.577215664901532;
+            for (uint i = 1; i <= xInt - 1; ++i) {
+                result += static_cast<T>(1) / i;
+            }
+            return result;
+        }
+
+        // positive half-integer
+        if(x - xInt == 0.5 && xInt <= 20) {   // psi(n+0.5) = -Euler_Mascheroni_const - 2*ln(2) + sum_from_k=1_to_n( 2/(2*k-1) ), for n = 1,2,3,...inf, we use this formula only for n <= 20 to avoid time consuming sum calculation for bigger n
+            T result = -0.577215664901532 - 2 * nd4j::math::nd4j_log<T,T>(2);
+            for (uint i = 1; i <= xInt; ++i) {
+                result += static_cast<T>(2) / (2*i - 1);
+            }
+            return result;
+        }
+
+        // positive, smaller then 5; we should use number > 5 in order to have satisfactory accuracy in asymptotic expansion
+        if(x < 5)
+            return diGammaScalar<T>(1 + x) - static_cast<T>(1) / x;   // recurrence formula psi(x) = psi(x+1) - 1/x.
+
+        // *** other positive **** //
+
+        // truncated expansion formula (from wiki)
+        // psi(x) = log(x) - 1/(2*x) - 1/(12*x^2) + 1/(120*x^4) - 1/(252*x^6) + 1/(240*x^8) - 5/(660*x^10) + 691/(32760*x^12) - 1/(12*x^14) + ...
+
+        if(x >= (sizeof(T) > 4 ? 1.e16 : 1.e8))   // if x is too big take into account only log(x)
+            return nd4j::math::nd4j_log<T,T>(x);
+
+        // coefficients used in truncated asymptotic expansion formula
+        const T coeffs[7] = {-(T)1/12, (T)1/120, -(T)1/252, (T)1/240, -(T)5/660, (T)691/32760, -(T)1/12};
+        // const T coeffs[7] = {-0.0833333333333333, 0.00833333333333333, -0.00396825396825397, 0.00416666666666667, -0.00757575757575758, 0.0210927960927961, -0.0833333333333333};
+
+        const T x2Inv = static_cast<T>(1) / (x * x);
+        T result = 0;
+
+        for (int i = 6; i >= 0; --i)
+            result = (result + coeffs[i]) * x2Inv;
+        return result + nd4j::math::nd4j_log<T,T>(x) - static_cast<T>(0.5) / x;
+    }
+
+}
+}
+}
+
+#endif //LIBND4J_GAMMAMATHFUNC_H
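In formula form, the branches of diGammaScalar implement the usual digamma identities:

    reflection:     \psi(x) = \psi(1-x) - \pi\cot(\pi x)                     (non-integer x <= 0)
    recurrence:     \psi(x) = \psi(x+1) - \tfrac{1}{x}                       (used to push x above 5)
    integers:       \psi(n) = -\gamma + \sum_{k=1}^{n-1} \tfrac{1}{k}
    half-integers:  \psi(n+\tfrac12) = -\gamma - 2\ln 2 + \sum_{k=1}^{n} \tfrac{2}{2k-1}
    asymptotic:     \psi(x) \approx \ln x - \tfrac{1}{2x} - \tfrac{1}{12x^2} + \tfrac{1}{120x^4} - \tfrac{1}{252x^6} + \dots

with \gamma \approx 0.577215664901532 the Euler-Mascheroni constant; the coeffs[] array holds the coefficients of the truncated asymptotic series.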
@@ -36,7 +36,7 @@ namespace platforms {
 //////////////////////////////////////////////////////////////////////
 static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, const NDArray *weights,
                           const NDArray *bias, NDArray *output, const int kH, const int kW, const int sH,
-                          const int sW, int pH, int pW, const int dH, const int dW, const int isSameMode,
+                          const int sW, int pH, int pW, const int dH, const int dW, const int paddingMode,
                           const int isNCHW) {
 
     int bS, iC, iH, iW, oC, oH, oW;   // batch size, input channels, input height/width, output channels, output height/width;
@@ -44,8 +44,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *output, bS, iC, iH, iW, oC, oH, oW,
                                                indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH);
 
-    if(isSameMode) // SAME
-        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
 
     dnnl_memory_desc_t empty;
     dnnl::memory::desc conv_src_md(empty), conv_weights_md(empty), conv_bias_md(empty), conv_dst_md(
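calcPadding2D is now called unconditionally with paddingMode, so the helper itself handles VALID (0), SAME (1) and CAUSAL (2) instead of being invoked only for SAME; this is what the commit message refers to as the corrected causal conv1d padding. For orientation only (the general causal-padding convention, not quoted from this repository): a causal convolution pads on the left by (k - 1) * d, so with kW = 2, dW = 1, sW = 1 and iW = 8, as in the conv1d_causal_7 test added at the end of this diff, the output width stays oW = (iW - 1) / sW + 1 = 8.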
@@ -53,7 +52,7 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
     dnnl::memory::desc user_src_md(empty), user_weights_md(empty), user_bias_md(empty), user_dst_md(
             empty);
     dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
-    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
+    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
                                            bS, iC, iH, iW, oC, oH, oW, input, nullptr, weights, nullptr,
                                            bias, output,
                                            &conv_src_md, nullptr, &conv_weights_md, nullptr,
@@ -115,13 +114,11 @@ static void conv2d_mkldnn(nd4j::graph::Context &block, const NDArray *input, con
 
 //////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(conv2d) {
-    auto input = INPUT_VARIABLE(
-            0);                                                        // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+    auto input = INPUT_VARIABLE(0);                                    // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
     auto weights = INPUT_VARIABLE(1);                                  // [kH, kW, iC, oC] always
     auto bias = block.width() > 2 ? INPUT_VARIABLE(2) : nullptr;       // [oC]
 
-    auto output = OUTPUT_VARIABLE(
-            0);                                                        // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
+    auto output = OUTPUT_VARIABLE(0);                                  // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW)
 
     int sH = INT_ARG(2);                                               // strides height
     int sW = INT_ARG(3);                                               // strides width
@@ -129,13 +126,13 @@ PLATFORM_IMPL(conv2d) {
     int pW = INT_ARG(5);                                               // paddings width
     int dH = INT_ARG(6);                                               // dilations height
     int dW = INT_ARG(7);                                               // dilations width
-    int isSameMode = INT_ARG(8);                                       // 0-VALID, 1-SAME
+    int paddingMode = INT_ARG(8);                                      // 0-VALID, 1-SAME
     bool isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1; // INT_ARG(9): 0-NCHW, 1-NHWC
 
     int kH = INT_ARG(0) > 0 ? INT_ARG(0) : static_cast<int>(weights->sizeAt(0)); // filter(kernel) height
     int kW = INT_ARG(1) > 0 ? INT_ARG(1) : static_cast<int>(weights->sizeAt(1)); // filter(kernel) width
 
-    conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW);
+    conv2d_mkldnn(block, input, weights, bias, output, kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW);
 
     return Status::OK();
 }
@@ -155,18 +152,13 @@ PLATFORM_CHECK(conv2d) {
 
 //////////////////////////////////////////////////////////////////////
 PLATFORM_IMPL(conv2d_bp) {
-    auto input = INPUT_VARIABLE(
-            0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
-    auto weights = INPUT_VARIABLE(
-            1);                                                // [kH, kW, iC, oC] always
+    auto input = INPUT_VARIABLE(0);                            // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
+    auto weights = INPUT_VARIABLE(1);                          // [kH, kW, iC, oC] always
     auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
-    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
-            2);                                                // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
+    auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
 
-    auto gradI = OUTPUT_VARIABLE(
-            0);                                                // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
-    auto gradW = OUTPUT_VARIABLE(
-            1);                                                // [kH, kW, iC, oC] always
+    auto gradI = OUTPUT_VARIABLE(0);                           // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
+    auto gradW = OUTPUT_VARIABLE(1);                           // [kH, kW, iC, oC] always
     auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr;    // [oC]
 
     int kH = INT_ARG(0);                                               // filter(kernel) height
@@ -177,7 +169,7 @@ PLATFORM_IMPL(conv2d_bp) {
     int pW = INT_ARG(5);                                               // paddings width
     int dH = INT_ARG(6);                                               // dilations height
     int dW = INT_ARG(7);                                               // dilations width
-    int isSameMode = INT_ARG(8);                                       // 0-VALID, 1-SAME
+    int paddingMode = INT_ARG(8);                                      // 0-VALID, 1-SAME
     int isNCHW = block.getIArguments()->size() > 9 ? !INT_ARG(9) : 1;  // INT_ARG(9): 0-NCHW, 1-NHWC
 
     REQUIRE_TRUE(input->rankOf() == 4, 0,
@@ -195,8 +187,7 @@ PLATFORM_IMPL(conv2d_bp) {
     ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, *input, *gradO, bS, iC, iH, iW, oC, oH, oW, indIOioC,
                                                indIiH, indWiC, indWoC, indWkH, indOoH);
 
-    if (isSameMode) // SAME
-        ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW);
+    ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW, paddingMode);
 
     dnnl_memory_desc_t empty;
     dnnl::memory::desc conv_src_md(empty), conv_diff_src_md(empty), conv_weights_md(empty),
@@ -204,7 +195,7 @@ PLATFORM_IMPL(conv2d_bp) {
     dnnl::memory::desc user_src_md(empty), user_diff_src_md(empty), user_weights_md(empty),
             user_diff_weights_md(empty), user_bias_md(empty), user_dst_md(empty);
     dnnl::memory::dims conv_strides, conv_padding, conv_padding_r, conv_dilation;
-    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, isSameMode, isNCHW,
+    mkldnnUtils::getMKLDNNMemoryDescConv2d(kH, kW, sH, sW, pH, pW, dH, dW, paddingMode, isNCHW,
                                            bS, iC, iH, iW, oC, oH, oW, input, gradI, weights, gradW,
                                            gradB, gradO,
                                            &conv_src_md, &conv_diff_src_md, &conv_weights_md,
@ -342,18 +333,13 @@ PLATFORM_CHECK(conv2d_bp) {
|
||||||
if (::optimalLevel() < 2)
|
if (::optimalLevel() < 2)
|
||||||
return false;
|
return false;
|
||||||
|
|
||||||
auto input = INPUT_VARIABLE(
|
auto input = INPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
||||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW)
|
auto weights = INPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
||||||
auto weights = INPUT_VARIABLE(
|
|
||||||
1); // [kH, kW, iC, oC] always
|
|
||||||
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
auto bias = block.width() > 3 ? INPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(
|
auto gradO = block.width() > 3 ? INPUT_VARIABLE(3) : INPUT_VARIABLE(2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
||||||
2); // [bS, oH, oW, oC] (NHWC) or [bS, oC, oH, oW] (NCHW), epsilon_next
|
|
||||||
|
|
||||||
auto gradI = OUTPUT_VARIABLE(
|
auto gradI = OUTPUT_VARIABLE(0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
||||||
0); // [bS, iH, iW, iC] (NHWC) or [bS, iC, iH, iW] (NCHW), epsilon
|
auto gradW = OUTPUT_VARIABLE(1); // [kH, kW, iC, oC] always
|
||||||
auto gradW = OUTPUT_VARIABLE(
|
|
||||||
1); // [kH, kW, iC, oC] always
|
|
||||||
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
auto gradB = block.width() > 3 ? OUTPUT_VARIABLE(2) : nullptr; // [oC]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
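
Note: with the switch from the boolean isSameMode to an integer paddingMode (0-VALID, 1-SAME, plus 2-CAUSAL in the 1D ops), calcPadding2D is now called unconditionally and decides internally how much padding to apply. Below is a minimal standalone sketch of the SAME-padding arithmetic this relies on; the helper name and the exact split convention are assumptions for illustration, not the ConvolutionUtils implementation.

    #include <algorithm>
    #include <cstdio>

    // SAME padding for one spatial dimension: the output keeps ceil(i / s) steps and
    // the required total padding is split with the larger half at the end (assumed convention).
    static void samePadding1D(int i, int k, int s, int d, int& pBegin, int& pEnd) {
        const int o     = (i + s - 1) / s;                                  // output length
        const int total = std::max((o - 1) * s + (k - 1) * d + 1 - i, 0);   // total padding needed
        pBegin = total / 2;
        pEnd   = total - pBegin;
    }

    int main() {
        int pB = 0, pE = 0;
        samePadding1D(/*i*/ 28, /*k*/ 3, /*s*/ 2, /*d*/ 1, pB, pE);
        printf("pad begin = %d, pad end = %d\n", pB, pE);                   // 0 and 1 for these values
        return 0;
    }
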
@ -1568,11 +1568,39 @@ namespace simdOps {
            return opOutput + old;
        }

        op_def static Z update(X old, X opOutput, X *extraParams) {
            return opOutput + old;
        }

+       op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) {
+           return reduction;
+       }
+   };
+
+   template <typename X, typename Z>
+   class IsNegative {
+   public:
+       no_op_exec_special_bool
+       no_op_exec_special_bool_cuda
+
+       no_op_exec_special_accumulation
+       no_op_exec_special_accumulation_cuda
+
+       op_def static Z op(X d1, X *params) {
+           return d1 < (X)0.f;
+       }
+
+       op_def static X startingValue(const X *input) {
+           return static_cast<X>(0);
+       }
+
+       op_def static Z merge(X old, X opOutput, X *extraParams) {
+           return opOutput + old;
+       }
+
+       op_def static Z update(X old, X opOutput, X *extraParams) {
+           return opOutput + old;
+       }
+
        op_def static Z postProcess(X reduction, Nd4jLong n, X *extraParams) {
            return reduction;
        }
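
Note: the no_op_exec_special_bool and no_op_exec_special_accumulation macros suggest the new IsNegative functor serves both as an element-wise boolean transform and as an accumulating reduction, with op() as the predicate and merge()/update() summing partial results. A standalone sketch of that behaviour in plain C++, for illustration only and not the libnd4j template machinery:

    #include <cstdio>

    // Mirrors the functor above: op() is a per-element predicate, update() sums,
    // postProcess() passes the sum through, so reducing the predicate counts negatives.
    template <typename X>
    struct IsNegativeSketch {
        static bool op(X d1)                 { return d1 < static_cast<X>(0.f); }
        static X    startingValue()          { return static_cast<X>(0); }
        static X    update(X old, X out)     { return out + old; }
        static X    postProcess(X reduction) { return reduction; }
    };

    int main() {
        const float data[] = {-1.5f, 0.f, 2.f, -0.25f, 3.f};
        float acc = IsNegativeSketch<float>::startingValue();
        for (float v : data)
            acc = IsNegativeSketch<float>::update(acc, IsNegativeSketch<float>::op(v) ? 1.f : 0.f);
        printf("negative count = %g\n", IsNegativeSketch<float>::postProcess(acc));   // prints 2
        return 0;
    }
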
@ -1008,6 +1008,38 @@ TEST_F(ConvolutionTests1, conv1d_causal_6) {
    delete results;
}

+//////////////////////////////////////////////////////////////////////
+TEST_F(ConvolutionTests1, conv1d_causal_7) {
+
+    int bS=2, iW=8, iC=3, oC=4, kW=2, sW=1, pW=0, dW=1;
+    int oW = (iW-1)/sW + 1;
+    int paddingMode = 2;             // CAUSAL
+    int dataFormat  = 1;             // 1-NHWC, 0-NCHW
+
+    NDArray input('c', {bS, iW, iC}, nd4j::DataType::FLOAT32);
+    NDArray weights('c', {kW, iC, oC}, nd4j::DataType::FLOAT32);
+
+    NDArray expOutput('c', {bS, oW, oC}, {11.000000, 11.600000, 12.200000, 12.800000, 30.099998, 32.200001, 34.299999, 36.400002, 49.899998, 53.800003, 57.699997,
+        61.599998, 69.699997, 75.400002, 81.099998, 86.800003, 89.500000, 97.000000, 104.500000, 112.000000, 109.300003, 118.600006, 127.899994, 137.199997, 129.100006,
+        140.199997, 151.300003, 162.399994, 148.899994, 161.800003, 174.699997, 187.600006, 133.399994, 141.200012, 149.000000, 156.800003, 188.500000, 205.000000,
+        221.500000, 238.000000, 208.299988, 226.600006, 244.899994, 263.200012, 228.100006, 248.200012, 268.299988, 288.399994, 247.899994, 269.799988, 291.700012,
+        313.600006, 267.700012, 291.399994, 315.100006, 338.799988, 287.500000, 313.000000, 338.500000, 364.000000, 307.299988, 334.600006, 361.899994, 389.200012}, nd4j::DataType::FLOAT32);
+
+    input.linspace(1., 1.);
+    weights.linspace(0.1, 0.1);
+
+    nd4j::ops::conv1d op;
+    auto results = op.execute({&input, &weights}, {}, {kW, sW, pW, dW, paddingMode, dataFormat});
+    auto output = results->at(0);
+
+    ASSERT_EQ(Status::OK(), results->status());
+
+    ASSERT_TRUE(expOutput.isSameShape(output));
+    ASSERT_TRUE(expOutput.equalsTo(output));
+
+    delete results;
+}
+
//////////////////////////////////////////////////////////////////////
TEST_F(ConvolutionTests1, conv1d_causal_bp_1) {
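
Note: for the causal mode the test uses oW = (iW-1)/sW + 1, and a causal 1D convolution pads only on the left, by (kW - 1) * dW zeros, so the first output step sees the zero padding plus input step 0. The first expected value can be checked by hand from the test's linspace data; the sketch below reproduces it (plain C++, for illustration only).

    #include <cstdio>

    int main() {
        const int kW = 2, iC = 3, oC = 4;
        // input.linspace(1., 1.) over {bS, iW, iC} = {2, 8, 3}: step 0 of batch 0 is {1, 2, 3}
        const float x0[iC] = {1.f, 2.f, 3.f};
        // weights.linspace(0.1, 0.1) over {kW, iC, oC} in c-order
        float w[kW][iC][oC];
        float v = 0.1f;
        for (int k = 0; k < kW; ++k)
            for (int c = 0; c < iC; ++c)
                for (int o = 0; o < oC; ++o) {
                    w[k][c][o] = v;
                    v += 0.1f;
                }
        // output(0, 0, 0): kernel tap 0 hits the left zero-padding, tap 1 hits x0
        float out = 0.f;
        for (int c = 0; c < iC; ++c)
            out += x0[c] * w[1][c][0];
        printf("output(0,0,0) = %.1f\n", out);   // 11.0, matching the first entry of expOutput above
        return 0;
    }
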
@ -428,10 +428,11 @@ TEST_F(DeclarableOpsTests13, CellContains_test_1) {
TEST_F(DeclarableOpsTests13, adjustHue_1) {

    NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
+   NDArray factor = NDArrayFactory::create<float>(0.5);
    NDArray exp ('c', {2,2,3}, {100,0,44, 208,5,220, 177,230,97, 2,255,244}, nd4j::DataType::FLOAT32);

    nd4j::ops::adjust_hue op;
-   auto results = op.execute({&input}, {0.5}, {2});
+   auto results = op.execute({&input, &factor}, {}, {2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

@ -525,10 +526,11 @@ TEST_F(DeclarableOpsTests13, adjustHue_5) {
TEST_F(DeclarableOpsTests13, adjustSaturation_1) {

    NDArray input('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
+   NDArray factor = NDArrayFactory::create<float>(0.5);
    NDArray exp ('c', {2,2,3}, {50,100,78, 118.5,220,112.5, 190,163.5,230, 255,128.5,134}, nd4j::DataType::FLOAT32);

    nd4j::ops::adjust_saturation op;
-   auto results = op.execute({&input}, {0.5}, {2});
+   auto results = op.execute({&input, &factor}, {}, {2});

    ASSERT_EQ(ND4J_STATUS_OK, results->status());

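
Note: both tests now supply the scale factor as a rank-0 NDArray input instead of a floating-point T-argument, exercising the new option of passing the factor as a scalar array. A minimal usage fragment in the style of these tests follows (not a standalone program; the integer argument 2 is assumed here to select the channel axis of the {2,2,3} image).

    NDArray image ('c', {2,2,3}, {0,100,56, 17,220,5, 150,97,230, 255,2,13}, nd4j::DataType::FLOAT32);
    NDArray factor = NDArrayFactory::create<float>(0.5);

    nd4j::ops::adjust_hue op;
    auto results = op.execute({&image, &factor}, {}, {2});   // factor supplied as a scalar-array input
    // ... inspect results->at(0), then:
    delete results;
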
@ -159,18 +159,19 @@ TEST_F(DeclarableOpsTests15, Test_standarize_bp_1) {

TEST_F(DeclarableOpsTests15, Test_AdjustContrast_1) {
    auto x = NDArrayFactory::create<double>('c', {4,4,3});
-   auto e = NDArrayFactory::create<double>('c', {4,4,3}, {
-            -21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
+   NDArray factor = NDArrayFactory::create<double>(2.);
+   auto e = NDArrayFactory::create<double>('c', {4,4,3}, {-21.5, -20.5, -19.5, -15.5, -14.5, -13.5, -9.5, -8.5, -7.5, -3.5, -2.5, -1.5,
             2.5, 3.5, 4.5, 8.5, 9.5, 10.5, 14.5, 15.5, 16.5, 20.5, 21.5, 22.5,
             26.5, 27.5, 28.5, 32.5, 33.5, 34.5, 38.5, 39.5, 40.5, 44.5, 45.5, 46.5,
-            50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5
-   });
+            50.5, 51.5, 52.5, 56.5, 57.5, 58.5, 62.5, 63.5, 64.5, 68.5, 69.5, 70.5});

    x.linspace(1.);
    nd4j::ops::adjust_contrast op;
-   auto result = op.execute({&x}, {2.}, {}, {});
+   auto result = op.execute({&x, &factor}, {}, {}, {});
    ASSERT_EQ(Status::OK(), result->status());
    auto out = result->at(0);
-   // out->printIndexedBuffer("Adjusted Constrast");
    ASSERT_TRUE(e.equalsTo(out));
    delete result;
}

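
Note: the expected buffer follows the usual contrast-adjustment rule of scaling every value towards its per-channel mean (an assumed definition, but consistent with the numbers in this test). With x.linspace(1.) over {4,4,3}, channel 0 holds 1, 4, ..., 46, so the first expected entry works out as:

    \text{out} = (x - \mu_c)\cdot f + \mu_c, \qquad
    \mu_{c=0} = \frac{1 + 46}{2} = 23.5, \qquad
    (1 - 23.5)\cdot 2 + 23.5 = -21.5 .
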
@ -1774,6 +1774,28 @@ TEST_F(DeclarableOpsTests3, betainc_test10) {
    delete results;
}

+///////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests3, betainc_test11) {
+
+    NDArray a('c', {4}, {0.7788f, 0.8012f, 0.7244f, 0.2309f}, nd4j::DataType::FLOAT32);
+    NDArray b('c', {4}, {0.7717f, 0.9281f, 0.9846f, 0.4838f}, nd4j::DataType::FLOAT32);
+    NDArray x('c', {4}, {0.9441f, 0.5957f, 0.8669f, 0.3502f}, nd4j::DataType::FLOAT32);
+
+    NDArray expected('c', {4}, {0.912156, 0.634443, 0.898314, 0.624544}, nd4j::DataType::FLOAT32);
+
+    nd4j::ops::betainc op;
+    auto results = op.execute({&a, &b, &x}, {}, {});
+
+    ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+    auto *output = results->at(0);
+
+    ASSERT_TRUE(expected.isSameShape(output));
+    ASSERT_TRUE(expected.equalsTo(output));
+
+    delete results;
+}
+
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, zeta_test1) {

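
Note: betainc computes the regularized incomplete beta function; for reference, the standard definition is

    I_x(a, b) = \frac{1}{B(a, b)} \int_0^x t^{a-1} (1 - t)^{b-1} \, dt,
    \qquad B(a, b) = \frac{\Gamma(a)\,\Gamma(b)}{\Gamma(a + b)} .
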
@ -2092,8 +2114,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
    x.linspace(10.);

    auto expected= NDArrayFactory::create<double>('c', {3,3}, {1.05166336e-01,-9.04983497e-03, 1.31009323e-03,-2.44459433e-04, 5.31593880e-05,-1.28049888e-05, 3.31755364e-06,-9.07408791e-07, 2.58758130e-07});
+   nd4j::ops::polygamma op;
+   auto results = op.execute({&n, &x}, {}, {});

-   //ASSERT_FALSE(true);
+   ASSERT_EQ(ND4J_STATUS_OK, results->status());

+   auto output = results->at(0);
+
+   ASSERT_TRUE(expected.isSameShape(output));
+   ASSERT_TRUE(expected.equalsTo(output));
+
+   delete results;
+}
+
+TEST_F(DeclarableOpsTests3, polygamma_test4) {
+
+   NDArray n('c', {3,4}, {/*0.7788*/0, 0,1,2,3,4,5,6,7,8,9,10}, nd4j::DataType::DOUBLE);
+   NDArray x('c', {3,4}, {0.7717,0.9281,0.9846,0.4838,0.6433,0.6041,0.6501,0.7612,0.7605,0.3948,0.9493,0.8600}, nd4j::DataType::DOUBLE);
+
+   NDArray expected('c', {3,4}, {/*std::numeric_limits<double>::quiet_NaN()*/-1.031918, -7.021327e-01, 1.682743e+00, -1.851378e+01,3.604167e+01, -3.008293e+02,
+       1.596005e+03, -4.876665e+03,4.510025e+04, -1.730340e+08, 6.110257e+05, -1.907087e+07}, nd4j::DataType::DOUBLE);
+
    nd4j::ops::polygamma op;
    auto results = op.execute({&n, &x}, {}, {});
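
Note: for reference, polygamma(n, x) is the n-th derivative of the digamma function (standard definition):

    \psi^{(n)}(x) = \frac{d^{n}}{dx^{n}} \psi(x) = \frac{d^{\,n+1}}{dx^{\,n+1}} \ln \Gamma(x) .
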
@ -2108,6 +2148,26 @@ TEST_F(DeclarableOpsTests3, polygamma_test3) {
    delete results;
}

+TEST_F(DeclarableOpsTests3, digamma_1) {
+
+   NDArray x('c', {18}, {-25, -24.99999, -21.5, -21.2, -5.5, -4.1, -2.1, -0.5, -0.3, 0., 0.2, 1, 1.5, 2.2, 5.2, 19., 21, 22.2}, nd4j::DataType::DOUBLE);
+
+   NDArray expected('c', {18}, {std::numeric_limits<double>::infinity(), -99996.761229, 3.091129, 7.401432, 1.792911,11.196838,10.630354, 0.03649, 2.11331,
+       std::numeric_limits<double>::infinity(),-5.28904,-0.577216, 0.03649, 0.544293, 1.549434,2.917892, 3.020524, 3.077401}, nd4j::DataType::DOUBLE);
+
+   nd4j::ops::digamma op;
+   auto results = op.execute({&x}, {}, {});
+
+   ASSERT_EQ(ND4J_STATUS_OK, results->status());
+
+   auto output = results->at(0);
+
+   ASSERT_TRUE(expected.isSameShape(output));
+   ASSERT_TRUE(expected.equalsTo(output));
+
+   delete results;
+}
+
///////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests3, svd_test1) {

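
Note: digamma is the logarithmic derivative of the gamma function and has poles at zero and the negative integers, which is why the expected buffer holds infinity at x = -25 and x = 0; likewise psi(1) = -gamma, matching the -0.577216 entry for x = 1 (standard facts, stated here only to explain the test data):

    \psi(x) = \frac{d}{dx} \ln \Gamma(x) = \frac{\Gamma'(x)}{\Gamma(x)}, \qquad \psi(1) = -\gamma \approx -0.577216 .
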