Shugeo lgamma (#170)
* lgamma op. Initial version. Signed-off-by: shugeo <sgazeos@gmail.com> * Refactored lgamma op and test. Signed-off-by: shugeo <sgazeos@gmail.com> * Lgamma wrapper * Added TF mapping Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
cae5ef4180
commit
6943a5f57a
|
@ -0,0 +1,50 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author George A. Shulinok <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_lgamma)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/lgamma.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
|
||||
OP_IMPL(lgamma, 1, 1, true) {
|
||||
|
||||
auto x = INPUT_VARIABLE(0);
|
||||
auto z = OUTPUT_VARIABLE(0);
|
||||
|
||||
helpers::lgamma(block.launchContext(), *x, *z);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
DECLARE_TYPES(lgamma) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes({ALL_FLOATS}) // as TF says
|
||||
->setSameMode(true);
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -527,6 +527,20 @@ namespace nd4j {
|
|||
DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This op calculates lgamma function lgamma(x) = log(Gamma(x))
|
||||
*
|
||||
* Input arrays:
|
||||
* 0: x - input matrix
|
||||
*
|
||||
* Output array:
|
||||
* 0: log of Gamma(x)
|
||||
*
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_lgamma)
|
||||
DECLARE_OP(lgamma, 1, 1, true);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This op calculates digamma function psi(x) = derivative of log(Gamma(x))
|
||||
*
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author George A. Shulinok <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include<ops/declarable/helpers/lgamma.h>
|
||||
#include <execution/Threads.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// calculate digamma function for array elements
|
||||
template <typename T>
|
||||
static void lgamma_(NDArray& x, NDArray& z) {
|
||||
|
||||
auto lgammaProc = LAMBDA_T(x_) {
|
||||
return T(DataTypeUtils::fromT<T>() == DataType::DOUBLE?::lgamma(x_): ::lgammaf(x_)); //math::nd4j_log<T,T>(math::nd4j_gamma<T,T>(x));
|
||||
};
|
||||
|
||||
x.applyLambda<T>(lgammaProc, z);
|
||||
}
|
||||
|
||||
void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z) {
|
||||
|
||||
BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES);
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2019 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author George A. Shulinok <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#include<ops/declarable/helpers/lgamma.h>
|
||||
//#include <execution/Threads.h>
|
||||
//#include <helper_math.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////
|
||||
// calculate digamma function for array elements
|
||||
template <typename T>
|
||||
static void lgamma_(NDArray& x, NDArray& z) {
|
||||
//auto dtype = x.dataType();
|
||||
auto lgammaProc = LAMBDA_T(x_, dtype) {
|
||||
return T(DataTypeUtils::fromT<T>() == DataType::DOUBLE?::lgamma(x_): ::lgammaf(x_)); //math::nd4j_log<T,T>(math::nd4j_gamma<T,T>(x));
|
||||
};
|
||||
|
||||
x.applyLambda(lgammaProc, z);
|
||||
}
|
||||
|
||||
void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z) {
|
||||
|
||||
BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES);
|
||||
}
|
||||
|
||||
BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES);
|
||||
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,40 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author George A. Shulinok <sgazeos@gmail.com>
|
||||
//
|
||||
|
||||
#ifndef __LIBND4J_L_GAMMA__H__
|
||||
#define __LIBND4J_L_GAMMA__H__
|
||||
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include "NDArray.h"
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
|
||||
// calculate the digamma function for each element for array
|
||||
void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z);
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#endif //__LIBND4J_L_GAMMA__H__
|
|
@ -579,6 +579,29 @@ TEST_F(DeclarableOpsTests10, IGamma_Test2) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, LGamma_Test1) {
|
||||
|
||||
auto x = NDArrayFactory::create<double>('c', {3, 3}, {0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.});
|
||||
|
||||
auto exp = NDArrayFactory::create<double>('c', {3,3}, {
|
||||
2.2527127 , 0.5723649 , 0.26086727,
|
||||
-0.12078223, -0.09580769, 0.,
|
||||
0.28468287, 0.4348206 , 0.6931472
|
||||
});
|
||||
|
||||
nd4j::ops::lgamma op;
|
||||
auto result = op.execute({&x}, {}, {}, {});
|
||||
ASSERT_EQ(ND4J_STATUS_OK, result->status());
|
||||
auto z = result->at(0);
|
||||
// z->printBuffer("OUtput");
|
||||
// exp.printBuffer("EXpect");
|
||||
ASSERT_TRUE(exp.isSameShape(z));
|
||||
ASSERT_TRUE(exp.equalsTo(z));
|
||||
|
||||
delete result;
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, range_test10) {
|
||||
|
||||
|
|
|
@ -613,6 +613,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.custom.BetaInc.class,
|
||||
org.nd4j.linalg.api.ops.custom.MatrixBandPart.class,
|
||||
org.nd4j.linalg.api.ops.custom.Polygamma.class,
|
||||
org.nd4j.linalg.api.ops.custom.Lgamma.class,
|
||||
org.nd4j.linalg.api.ops.custom.RandomCrop.class,
|
||||
org.nd4j.linalg.api.ops.custom.Roll.class,
|
||||
org.nd4j.linalg.api.ops.custom.ToggleBits.class,
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
/* ******************************************************************************
|
||||
* Copyright (c) 2020 Konduit K.K.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
package org.nd4j.linalg.api.ops.custom;
|
||||
|
||||
import lombok.NoArgsConstructor;
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.autodiff.samediff.SDVariable;
|
||||
import org.nd4j.autodiff.samediff.SameDiff;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
@NoArgsConstructor
|
||||
public class Lgamma extends DynamicCustomOp {
|
||||
|
||||
public Lgamma(@NonNull INDArray x) {
|
||||
addInputArgument(x);
|
||||
}
|
||||
|
||||
public Lgamma(@NonNull INDArray x, INDArray output) {
|
||||
this(x);
|
||||
if (output != null) {
|
||||
addOutputArgument(output);
|
||||
}
|
||||
}
|
||||
|
||||
public Lgamma(@NonNull SameDiff sameDiff, @NonNull SDVariable x) {
|
||||
super("", sameDiff, new SDVariable[]{x});
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "lgamma";
|
||||
}
|
||||
|
||||
@Override
|
||||
public String tensorflowName() {
|
||||
return "Lgamma";
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||
int n = args().length;
|
||||
Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
|
||||
return Collections.singletonList(inputDataTypes.get(0));
|
||||
}
|
||||
}
|
|
@ -1207,6 +1207,18 @@ public class CustomOpsTests extends BaseNd4jTest {
|
|||
assertEquals(expected, output);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testLgamma() {
|
||||
INDArray x = Nd4j.createFromArray(new double[]{0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.}).reshape(3,3);
|
||||
INDArray expected = Nd4j.createFromArray(new double[]{
|
||||
2.2527127 , 0.5723649 , 0.26086727,
|
||||
-0.12078223, -0.09580769, 0.,
|
||||
0.28468287, 0.4348206 , 0.6931472
|
||||
}).reshape(3,3);
|
||||
INDArray[] ret = Nd4j.exec(new Lgamma(x));
|
||||
assertEquals(expected, ret[0]);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testRandomCrop() {
|
||||
INDArray x = Nd4j.createFromArray(new double[]{1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }).reshape(2,2,4);
|
||||
|
|
Loading…
Reference in New Issue