Shugeo lgamma (#170)

* lgamma op. Initial version.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Refactored lgamma op and test.

Signed-off-by: shugeo <sgazeos@gmail.com>

* Lgamma wrapper

* Added TF mapping

Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
Branch: master
Authored by shugeo on 2020-01-20 11:29:36 +02:00, committed by raver119
parent cae5ef4180, commit 6943a5f57a
9 changed files with 311 additions and 0 deletions


@@ -0,0 +1,50 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_lgamma)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/lgamma.h>
namespace nd4j {
namespace ops {
OP_IMPL(lgamma, 1, 1, true) {
    auto x = INPUT_VARIABLE(0);
    auto z = OUTPUT_VARIABLE(0);

    helpers::lgamma(block.launchContext(), *x, *z);

    return Status::OK();
}

DECLARE_TYPES(lgamma) {
    getOpDescriptor()
            ->setAllowedInputTypes({ALL_FLOATS}) // as TF says
            ->setSameMode(true);
}
}
}
#endif
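For orientation only (this paragraph and snippet are an editorial aside, not part of the commit): a minimal usage sketch of the op registered above, following the same calling convention as the C++ test added further down in this change. NDArrayFactory is assumed to be available, as it is in those tests.

static void lgammaUsageSketch() {
    // build a tiny input; at positive integer points lgamma reduces to log((n-1)!)
    auto x = NDArrayFactory::create<double>('c', {3}, {1., 2., 3.});

    nd4j::ops::lgamma op;
    auto result = op.execute({&x}, {}, {}, {});     // inputs, t-args, i-args, b-args
    auto z = result->at(0);                         // expected {0., 0., 0.6931472}

    delete result;
}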


@@ -527,6 +527,20 @@ namespace nd4j {
DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0);
#endif
/**
* This op calculates the lgamma function: lgamma(x) = log(Gamma(x))
*
* Input arrays:
* 0: x - input array (float types)
*
* Output array:
* 0: log(Gamma(x)), element-wise, with the same shape as x
*
*/
#if NOT_EXCLUDED(OP_lgamma)
DECLARE_OP(lgamma, 1, 1, true);
#endif
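As a quick sanity check on the relation stated in the comment above (an editorial aside, not part of the commit): at positive integer arguments the gamma function reduces to a factorial, which matches the expected values used in the tests added below.

$\mathrm{lgamma}(x) = \log \Gamma(x), \qquad \Gamma(n) = (n-1)! \;\Rightarrow\; \mathrm{lgamma}(1) = \mathrm{lgamma}(2) = 0, \quad \mathrm{lgamma}(3) = \log 2 \approx 0.6931472$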
/**
* This op calculates the digamma function psi(x) = derivative of log(Gamma(x))
*


@@ -0,0 +1,53 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/lgamma.h>
#include <execution/Threads.h>

namespace nd4j {
namespace ops {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
// calculate the lgamma function for each array element
template <typename T>
static void lgamma_(NDArray& x, NDArray& z) {
    auto lgammaProc = LAMBDA_T(x_) {
        return T(DataTypeUtils::fromT<T>() == DataType::DOUBLE ? ::lgamma(x_) : ::lgammaf(x_)); //math::nd4j_log<T,T>(math::nd4j_gamma<T,T>(x));
    };

    x.applyLambda<T>(lgammaProc, z);
}

void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z) {
    BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES);
}
}
}
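Conceptual illustration only (an editorial aside, not part of the commit and not the literal macro expansion): the BUILD_SINGLE_SELECTOR call above dispatches the templated helper on the runtime data type of the input array; the enum values and the float16 type below are assumptions about the available FLOAT_TYPES.

static void lgammaDispatchSketch(NDArray& x, NDArray& z) {
    // pick the template instantiation that matches the runtime element type
    switch (x.dataType()) {
        case DataType::DOUBLE:  lgamma_<double>(x, z);  break;
        case DataType::FLOAT32: lgamma_<float>(x, z);   break;
        case DataType::HALF:    lgamma_<float16>(x, z); break; // half precision, if enabled
        default: throw std::runtime_error("lgamma: only FLOAT_TYPES are supported");
    }
}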


@@ -0,0 +1,54 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#include <ops/declarable/helpers/lgamma.h>
//#include <execution/Threads.h>
//#include <helper_math.h>

namespace nd4j {
namespace ops {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
// calculate the lgamma function for each array element
template <typename T>
static void lgamma_(NDArray& x, NDArray& z) {
    auto lgammaProc = LAMBDA_T(x_) {
        return T(DataTypeUtils::fromT<T>() == DataType::DOUBLE ? ::lgamma(x_) : ::lgammaf(x_)); //math::nd4j_log<T,T>(math::nd4j_gamma<T,T>(x));
    };

    x.applyLambda<T>(lgammaProc, z);
}

void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z) {
    BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES);
}

BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES);
}
}
}


@@ -0,0 +1,40 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author George A. Shulinok <sgazeos@gmail.com>
//
#ifndef __LIBND4J_L_GAMMA__H__
#define __LIBND4J_L_GAMMA__H__
#include <ops/declarable/helpers/helpers.h>
#include "NDArray.h"
namespace nd4j {
namespace ops {
namespace helpers {
// calculate the lgamma function for each element of the array
void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z);
}
}
}
#endif //__LIBND4J_L_GAMMA__H__


@@ -579,6 +579,29 @@ TEST_F(DeclarableOpsTests10, IGamma_Test2) {
delete result;
}
//////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, LGamma_Test1) {
    auto x = NDArrayFactory::create<double>('c', {3, 3}, {0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.});
    auto exp = NDArrayFactory::create<double>('c', {3, 3}, {
             2.2527127,   0.5723649,  0.26086727,
            -0.12078223, -0.09580769, 0.,
             0.28468287,  0.4348206,  0.6931472
    });

    nd4j::ops::lgamma op;
    auto result = op.execute({&x}, {}, {}, {});
    ASSERT_EQ(ND4J_STATUS_OK, result->status());

    auto z = result->at(0);
//    z->printBuffer("Output");
//    exp.printBuffer("Expected");
    ASSERT_TRUE(exp.isSameShape(z));
    ASSERT_TRUE(exp.equalsTo(z));

    delete result;
}
//////////////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, range_test10) {


@@ -613,6 +613,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.custom.BetaInc.class,
org.nd4j.linalg.api.ops.custom.MatrixBandPart.class,
org.nd4j.linalg.api.ops.custom.Polygamma.class,
org.nd4j.linalg.api.ops.custom.Lgamma.class,
org.nd4j.linalg.api.ops.custom.RandomCrop.class,
org.nd4j.linalg.api.ops.custom.Roll.class,
org.nd4j.linalg.api.ops.custom.ToggleBits.class,


@@ -0,0 +1,64 @@
/* ******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.custom;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import java.util.Collections;
import java.util.List;
@NoArgsConstructor
public class Lgamma extends DynamicCustomOp {

    public Lgamma(@NonNull INDArray x) {
        addInputArgument(x);
    }

    public Lgamma(@NonNull INDArray x, INDArray output) {
        this(x);
        if (output != null) {
            addOutputArgument(output);
        }
    }

    public Lgamma(@NonNull SameDiff sameDiff, @NonNull SDVariable x) {
        super("", sameDiff, new SDVariable[]{x});
    }

    @Override
    public String opName() {
        return "lgamma";
    }

    @Override
    public String tensorflowName() {
        return "Lgamma";
    }

    @Override
    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
        int n = args().length;
        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
        return Collections.singletonList(inputDataTypes.get(0));
    }
}


@@ -1207,6 +1207,18 @@ public class CustomOpsTests extends BaseNd4jTest {
assertEquals(expected, output);
}
@Test
public void testLgamma() {
    INDArray x = Nd4j.createFromArray(new double[]{0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.}).reshape(3, 3);
    INDArray expected = Nd4j.createFromArray(new double[]{
             2.2527127,   0.5723649,  0.26086727,
            -0.12078223, -0.09580769, 0.,
             0.28468287,  0.4348206,  0.6931472
    }).reshape(3, 3);

    INDArray[] ret = Nd4j.exec(new Lgamma(x));
    assertEquals(expected, ret[0]);
}
@Test
public void testRandomCrop() {
INDArray x = Nd4j.createFromArray(new double[]{1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }).reshape(2,2,4);