diff --git a/libnd4j/include/ops/declarable/generic/parity_ops/lgamma.cpp b/libnd4j/include/ops/declarable/generic/parity_ops/lgamma.cpp
new file mode 100644
index 000000000..615190c2f
--- /dev/null
+++ b/libnd4j/include/ops/declarable/generic/parity_ops/lgamma.cpp
@@ -0,0 +1,50 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author George A. Shulinok
+//
+
+#include <op_boilerplate.h>
+#if NOT_EXCLUDED(OP_lgamma)
+
+#include <ops/declarable/CustomOperations.h>
+#include <ops/declarable/helpers/lgamma.h>
+
+namespace nd4j {
+namespace ops {
+
+OP_IMPL(lgamma, 1, 1, true) {
+
+    auto x = INPUT_VARIABLE(0);
+    auto z = OUTPUT_VARIABLE(0);
+
+    helpers::lgamma(block.launchContext(), *x, *z);
+
+    return Status::OK();
+}
+
+DECLARE_TYPES(lgamma) {
+    getOpDescriptor()
+            ->setAllowedInputTypes({ALL_FLOATS}) // as TF says
+            ->setSameMode(true);
+}
+
+}
+}
+
+#endif
\ No newline at end of file
diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h
index cbaae52f7..c0bf2ab60 100644
--- a/libnd4j/include/ops/declarable/headers/parity_ops.h
+++ b/libnd4j/include/ops/declarable/headers/parity_ops.h
@@ -527,6 +527,20 @@ namespace nd4j {
         DECLARE_CONFIGURABLE_OP(polygamma, 2, 1, false, 0, 0);
         #endif
 
+        /**
+         * This op calculates the lgamma function lgamma(x) = log(Gamma(x))
+         *
+         * Input arrays:
+         *    0: x - input array
+         *
+         * Output array:
+         *    0: log of Gamma(x), computed elementwise
+         *
+         */
+        #if NOT_EXCLUDED(OP_lgamma)
+        DECLARE_OP(lgamma, 1, 1, true);
+        #endif
+
         /**
          * This op calculates digamma function psi(x) = derivative of log(Gamma(x))
          *
diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp
new file mode 100644
index 000000000..2978a9d45
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cpu/lgamma.cpp
@@ -0,0 +1,53 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author George A. Shulinok
+//
+
+#include <ops/declarable/helpers/lgamma.h>
+#include <NDArray.h>
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+// calculate the lgamma function for array elements
+template <typename T>
+static void lgamma_(NDArray& x, NDArray& z) {
+
+    auto lgammaProc = LAMBDA_T(x_) {
+        return T(DataTypeUtils::fromT<T>() == DataType::DOUBLE ? ::lgamma(x_) : ::lgammaf(x_)); //math::nd4j_log(math::nd4j_gamma(x));
+    };
+
+    x.applyLambda(lgammaProc, z);
+}
+
+void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z) {
+
+    BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES);
+}
+
+BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES);
+
+
+
+}
+}
+}
+
diff --git a/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu b/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu
new file mode 100644
index 000000000..9b749c6e2
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/cuda/lgamma.cu
@@ -0,0 +1,54 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author George A. Shulinok
+//
+
+#include <ops/declarable/helpers/lgamma.h>
+//#include
+//#include
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+//////////////////////////////////////////////////////////////////////////
+// calculate the lgamma function for array elements
+template <typename T>
+static void lgamma_(NDArray& x, NDArray& z) {
+    //auto dtype = x.dataType();
+    auto lgammaProc = LAMBDA_T(x_) {
+        return T(DataTypeUtils::fromT<T>() == DataType::DOUBLE ? ::lgamma(x_) : ::lgammaf(x_)); //math::nd4j_log(math::nd4j_gamma(x));
+    };
+
+    x.applyLambda(lgammaProc, z);
+}
+
+void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z) {
+
+    BUILD_SINGLE_SELECTOR(x.dataType(), lgamma_, (x, z), FLOAT_TYPES);
+}
+
+BUILD_SINGLE_TEMPLATE(template void lgamma_, (NDArray& x, NDArray& z), FLOAT_TYPES);
+
+
+
+}
+}
+}
+
diff --git a/libnd4j/include/ops/declarable/helpers/lgamma.h b/libnd4j/include/ops/declarable/helpers/lgamma.h
new file mode 100644
index 000000000..48bcf1d73
--- /dev/null
+++ b/libnd4j/include/ops/declarable/helpers/lgamma.h
@@ -0,0 +1,40 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+
+//
+// @author George A. Shulinok
+//
+
+#ifndef __LIBND4J_L_GAMMA__H__
+#define __LIBND4J_L_GAMMA__H__
+
+#include <op_boilerplate.h>
+#include "NDArray.h"
+
+namespace nd4j {
+namespace ops {
+namespace helpers {
+
+    // calculate the lgamma function for each element of the array
+    void lgamma(nd4j::LaunchContext* context, NDArray& x, NDArray& z);
+
+}
+}
+}
+
+
+#endif //__LIBND4J_L_GAMMA__H__
diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
index f9382f6c7..30c645785 100644
--- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
+++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests10.cpp
@@ -579,6 +579,29 @@ TEST_F(DeclarableOpsTests10, IGamma_Test2) {
     delete result;
 }
 
+//////////////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests10, LGamma_Test1) {
+
+    auto x = NDArrayFactory::create<double>('c', {3, 3}, {0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.});
+
+    auto exp = NDArrayFactory::create<double>('c', {3,3}, {
+             2.2527127,   0.5723649,  0.26086727,
+            -0.12078223, -0.09580769, 0.,
+             0.28468287,  0.4348206,  0.6931472
+    });
+
+    nd4j::ops::lgamma op;
+    auto result = op.execute({&x}, {}, {}, {});
+    ASSERT_EQ(ND4J_STATUS_OK, result->status());
+    auto z = result->at(0);
+//    z->printBuffer("Output");
+//    exp.printBuffer("Expect");
+    ASSERT_TRUE(exp.isSameShape(z));
+    ASSERT_TRUE(exp.equalsTo(z));
+
+    delete result;
+}
+
 //////////////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests10, range_test10) {
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index 1693b80b5..e85c472c8 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -613,6 +613,7 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.custom.BetaInc.class,
             org.nd4j.linalg.api.ops.custom.MatrixBandPart.class,
             org.nd4j.linalg.api.ops.custom.Polygamma.class,
+            org.nd4j.linalg.api.ops.custom.Lgamma.class,
             org.nd4j.linalg.api.ops.custom.RandomCrop.class,
             org.nd4j.linalg.api.ops.custom.Roll.class,
             org.nd4j.linalg.api.ops.custom.ToggleBits.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lgamma.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lgamma.java
new file mode 100644
index 000000000..3df488120
--- /dev/null
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/custom/Lgamma.java
@@ -0,0 +1,64 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4j.linalg.api.ops.custom;
+
+import lombok.NoArgsConstructor;
+import lombok.NonNull;
+import org.nd4j.autodiff.samediff.SDVariable;
+import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.base.Preconditions;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.DynamicCustomOp;
+
+import java.util.Collections;
+import java.util.List;
+
+@NoArgsConstructor
+public class Lgamma extends DynamicCustomOp {
+
+    public Lgamma(@NonNull INDArray x) {
+        addInputArgument(x);
+    }
+
+    public Lgamma(@NonNull INDArray x, INDArray output) {
+        this(x);
+        if (output != null) {
+            addOutputArgument(output);
+        }
+    }
+
+    public Lgamma(@NonNull SameDiff sameDiff, @NonNull SDVariable x) {
+        super("", sameDiff, new SDVariable[]{x});
+    }
+
+    @Override
+    public String opName() {
+        return "lgamma";
+    }
+
+    @Override
+    public String tensorflowName() {
+        return "Lgamma";
+    }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) {
+        int n = args().length;
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
+}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
index 0ae56350d..15b9454da 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/custom/CustomOpsTests.java
@@ -1207,6 +1207,18 @@ public class CustomOpsTests extends BaseNd4jTest {
         assertEquals(expected, output);
     }
 
+    @Test
+    public void testLgamma() {
+        INDArray x = Nd4j.createFromArray(new double[]{0.1, 0.5, 0.7, 1.5, 1.7, 2.0, 2.5, 2.7, 3.}).reshape(3,3);
+        INDArray expected = Nd4j.createFromArray(new double[]{
+                 2.2527127,   0.5723649,  0.26086727,
+                -0.12078223, -0.09580769, 0.,
+                 0.28468287,  0.4348206,  0.6931472
+        }).reshape(3,3);
+        INDArray[] ret = Nd4j.exec(new Lgamma(x));
+        assertEquals(expected, ret[0]);
+    }
+
     @Test
     public void testRandomCrop() {
         INDArray x = Nd4j.createFromArray(new double[]{1.8, 2.5, 4., 9., 2.1, 2.4, 3., 9.,2.1, 2.1, 0.7, 0.1,3., 4.2, 2.2, 1. }).reshape(2,2,4);