Shugeo random uniform int (#30)
* Corrected randomuniform declaration. * Refactored uniform distribution for both cuda and cpu platforms. * Refactored uniform distribution and tests. * Fixed type usage with indices. * Refactored uniform distribution implementation and tests to full conform with TF implementation. * Refactored gamma function to use type util method. * Copyright changes and fixes with ConstantHelper. * Added error checking on allocate cuda device memory and operations.master
parent
df8b4e607a
commit
08853c7829
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -29,8 +30,8 @@
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
#include <Environment.h>
|
#include <Environment.h>
|
||||||
#include <ArrayOptions.h>
|
#include <ArrayOptions.h>
|
||||||
#include <templatemath.h>
|
//#include <templatemath.h>
|
||||||
#include <shape.h>
|
//#include <shape.h>
|
||||||
#include <helpers/logger.h>
|
#include <helpers/logger.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
|
@ -128,7 +129,9 @@ namespace nd4j {
|
||||||
// if both dtypes are the same - just return it
|
// if both dtypes are the same - just return it
|
||||||
if (typeX == typeY)
|
if (typeX == typeY)
|
||||||
return typeX;
|
return typeX;
|
||||||
|
auto nd4j_max = [](nd4j::DataType typeX, nd4j::DataType typeY) {
|
||||||
|
return typeX > typeY?typeX:typeY;
|
||||||
|
};
|
||||||
auto rX = isR(typeX);
|
auto rX = isR(typeX);
|
||||||
auto rY = isR(typeY);
|
auto rY = isR(typeY);
|
||||||
|
|
||||||
|
@ -144,7 +147,7 @@ namespace nd4j {
|
||||||
if (rX && rY) {
|
if (rX && rY) {
|
||||||
// if we allow precision boost, then we pick bigger data type
|
// if we allow precision boost, then we pick bigger data type
|
||||||
if (nd4j::Environment::getInstance()->precisionBoostAllowed()) {
|
if (nd4j::Environment::getInstance()->precisionBoostAllowed()) {
|
||||||
return nd4j::math::nd4j_max<nd4j::DataType>(typeX, typeY);
|
return nd4j_max(typeX, typeY);
|
||||||
} else {
|
} else {
|
||||||
// and we return first operand otherwise
|
// and we return first operand otherwise
|
||||||
return typeX;
|
return typeX;
|
||||||
|
@ -155,7 +158,7 @@ namespace nd4j {
|
||||||
// if that's not real type, we apply same rules
|
// if that's not real type, we apply same rules
|
||||||
if (!rX && !rY) {
|
if (!rX && !rY) {
|
||||||
if (nd4j::Environment::getInstance()->precisionBoostAllowed()) {
|
if (nd4j::Environment::getInstance()->precisionBoostAllowed()) {
|
||||||
return nd4j::math::nd4j_max<nd4j::DataType>(typeX, typeY);
|
return nd4j_max(typeX, typeY);
|
||||||
} else {
|
} else {
|
||||||
// and we return first operand otherwise
|
// and we return first operand otherwise
|
||||||
return typeX;
|
return typeX;
|
||||||
|
@ -367,8 +370,8 @@ FORCEINLINE std::string DataTypeUtils::asString(DataType dataType) {
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo) {
|
FORCEINLINE bool DataTypeUtils::castShapeInfo(const Nd4jLong *originalShapeInfo, T *newShapeInfo) {
|
||||||
|
auto shapeInfoLength = *originalShapeInfo * 2 + 4;
|
||||||
for (int e = 0; e < shape::shapeInfoLength(originalShapeInfo); e++) {
|
for (auto e = 0; e < shapeInfoLength; e++) {
|
||||||
if (originalShapeInfo[e] < static_cast<Nd4jLong>(DataTypeUtils::max<T>())) {
|
if (originalShapeInfo[e] < static_cast<Nd4jLong>(DataTypeUtils::max<T>())) {
|
||||||
newShapeInfo[e] = static_cast<T>(originalShapeInfo[e]);
|
newShapeInfo[e] = static_cast<T>(originalShapeInfo[e]);
|
||||||
} else
|
} else
|
||||||
|
|
|
@ -1,10 +1,26 @@
|
||||||
|
/**
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// Created by raver on 5/17/2019.
|
// Created by raver on 5/17/2019.
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <array/ConstantHolder.h>
|
|
||||||
#include <DataTypeUtils.h>
|
#include <DataTypeUtils.h>
|
||||||
|
#include <array/ConstantHolder.h>
|
||||||
|
#include <shape.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
ConstantHolder::ConstantHolder(const ConstantHolder& other) {
|
ConstantHolder::ConstantHolder(const ConstantHolder& other) {
|
||||||
|
@ -24,7 +40,7 @@ namespace nd4j {
|
||||||
bool ConstantHolder::hasBuffer() {
|
bool ConstantHolder::hasBuffer() {
|
||||||
return hasBuffer(DataTypeUtils::fromT<T>());
|
return hasBuffer(DataTypeUtils::fromT<T>());
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT bool ConstantHolder::hasBuffer, (), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT bool ConstantHolder::hasBuffer, (void), LIBND4J_TYPES);
|
||||||
|
|
||||||
void ConstantHolder::addBuffer(ConstantDataBuffer &pointer, nd4j::DataType dataType) {
|
void ConstantHolder::addBuffer(ConstantDataBuffer &pointer, nd4j::DataType dataType) {
|
||||||
_buffers[dataType] = pointer;
|
_buffers[dataType] = pointer;
|
||||||
|
@ -34,7 +50,7 @@ namespace nd4j {
|
||||||
void ConstantHolder::addBuffer(ConstantDataBuffer &pointer) {
|
void ConstantHolder::addBuffer(ConstantDataBuffer &pointer) {
|
||||||
addBuffer(pointer, DataTypeUtils::fromT<T>());
|
addBuffer(pointer, DataTypeUtils::fromT<T>());
|
||||||
}
|
}
|
||||||
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ConstantHolder::addBuffer, (ConstantDataBuffer&), LIBND4J_TYPES);
|
BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ConstantHolder::addBuffer, (ConstantDataBuffer& cb), LIBND4J_TYPES);
|
||||||
|
|
||||||
ConstantDataBuffer* ConstantHolder::getConstantDataBuffer(nd4j::DataType dataType) {
|
ConstantDataBuffer* ConstantHolder::getConstantDataBuffer(nd4j::DataType dataType) {
|
||||||
if (!hasBuffer(dataType))
|
if (!hasBuffer(dataType))
|
||||||
|
|
|
@ -195,7 +195,21 @@ namespace nd4j {
|
||||||
template <typename T>
|
template <typename T>
|
||||||
_CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index, T from, T to) {
|
_CUDA_HD FORCEINLINE T RandomGenerator::relativeT(Nd4jLong index, T from, T to) {
|
||||||
auto t = this->relativeT<T>(index);
|
auto t = this->relativeT<T>(index);
|
||||||
auto z = from + (t * (to - from));
|
auto z = from + T(t * (to - from));
|
||||||
|
return z;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
_CUDA_HD FORCEINLINE Nd4jLong RandomGenerator::relativeT(Nd4jLong index, Nd4jLong from, Nd4jLong to) {
|
||||||
|
auto t = this->relativeT<double>(index);
|
||||||
|
auto z = from + Nd4jLong(t * (to - from));
|
||||||
|
return z;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
_CUDA_HD FORCEINLINE int RandomGenerator::relativeT(Nd4jLong index, int from, int to) {
|
||||||
|
auto t = this->relativeT<float>(index);
|
||||||
|
auto z = from + float(t * (to - from));
|
||||||
return z;
|
return z;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -21,6 +22,7 @@
|
||||||
#include <exceptions/cuda_exception.h>
|
#include <exceptions/cuda_exception.h>
|
||||||
#include <ConstantHelper.h>
|
#include <ConstantHelper.h>
|
||||||
#include <DataTypeUtils.h>
|
#include <DataTypeUtils.h>
|
||||||
|
#include <shape.h>
|
||||||
#include <execution/LaunchContext.h>
|
#include <execution/LaunchContext.h>
|
||||||
#include <specials.h>
|
#include <specials.h>
|
||||||
#include <logger.h>
|
#include <logger.h>
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -23,6 +24,7 @@
|
||||||
|
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include <helpers/RandomLauncher.h>
|
#include <helpers/RandomLauncher.h>
|
||||||
|
#include <ops/declarable/helpers/random.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -35,41 +37,45 @@ namespace nd4j {
|
||||||
* TArgs[0] - min for rng
|
* TArgs[0] - min for rng
|
||||||
* TArgs[1] - max for rng
|
* TArgs[1] - max for rng
|
||||||
*/
|
*/
|
||||||
CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 2, 0) {
|
CUSTOM_OP_IMPL(randomuniform, 1, 1, true, 0, 0) {
|
||||||
// uniform distribution
|
// uniform distribution
|
||||||
auto rng = block.randomGenerator();
|
auto rng = block.randomGenerator();
|
||||||
|
auto dtype = DataType::FLOAT32;
|
||||||
|
if (block.getIArguments()->size())
|
||||||
|
dtype = (DataType)INT_ARG(0);
|
||||||
|
|
||||||
// FIXME: to be implemented
|
auto min = block.width() > 1?INPUT_VARIABLE(1):(NDArray*)nullptr;
|
||||||
/*
|
auto max = block.width() > 2?INPUT_VARIABLE(2):(NDArray*)nullptr;
|
||||||
if (rng == nullptr)
|
|
||||||
return Status::THROW("RNG is null, aborting...");
|
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
auto z = OUTPUT_VARIABLE(0);
|
REQUIRE_TRUE(output->dataType() == dtype, 0, "RandomUniform: data type of output should be equals to given.");
|
||||||
|
|
||||||
functions::random::RandomFunction<T>::template execTransform<randomOps::UniformDistribution<T>>(block.getRNG(), z->getBuffer(), z->getShapeInfo(), block.getTArguments()->data());
|
helpers::fillRandomUniform(block.launchContext(), rng, min, max, output);
|
||||||
|
|
||||||
STORE_RESULT(*z);
|
|
||||||
*/
|
|
||||||
REQUIRE_TRUE(block.numT() > 1, 0, "RandomUniform: to/from must be set");
|
|
||||||
|
|
||||||
RandomLauncher::fillUniform(block.launchContext(), rng, OUTPUT_VARIABLE(0), T_ARG(0), T_ARG(1));
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(randomuniform) {
|
DECLARE_SHAPE_FN(randomuniform) {
|
||||||
auto in = INPUT_VARIABLE(0);
|
auto in = INPUT_VARIABLE(0);
|
||||||
|
//auto min = INPUT_VARIABLE(1);
|
||||||
auto shape = in->template asVectorT<Nd4jLong>();
|
auto shape = in->template asVectorT<Nd4jLong>();
|
||||||
|
auto dtype = DataType::FLOAT32; //ArrayOptions::dataType(inputShape->at(1)); // output type is by given min
|
||||||
|
|
||||||
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(block.dataType(), 'c', shape);
|
if (block.getIArguments()->size())
|
||||||
|
dtype = (DataType)INT_ARG(0);
|
||||||
|
if (block.width() > 1)
|
||||||
|
REQUIRE_TRUE(dtype == INPUT_VARIABLE(1)->dataType(), 0, "RandomUniform: data type of output and min/max args should be the same");
|
||||||
|
|
||||||
|
auto newShape = ConstantShapeHelper::getInstance()->createShapeInfo(dtype, 'c', shape);
|
||||||
return SHAPELIST(newShape);
|
return SHAPELIST(newShape);
|
||||||
}
|
}
|
||||||
|
|
||||||
DECLARE_TYPES(randomuniform) {
|
DECLARE_TYPES(randomuniform) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(nd4j::DataType::ANY)
|
->setAllowedInputTypes(0, {ALL_INTS})
|
||||||
->setAllowedOutputTypes({ALL_FLOATS});
|
->setAllowedInputTypes(1, {ALL_INTS, ALL_FLOATS})
|
||||||
|
->setAllowedInputTypes(2, {ALL_INTS, ALL_FLOATS})
|
||||||
|
->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -33,8 +34,20 @@ namespace nd4j {
|
||||||
DECLARE_CUSTOM_OP(get_seed, -2, 1, false, 0, 0);
|
DECLARE_CUSTOM_OP(get_seed, -2, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/*
|
||||||
|
* random_uniform distribution for types int32,int64, float16, float and double
|
||||||
|
* by default dtype is float32
|
||||||
|
*
|
||||||
|
* input:
|
||||||
|
* 0 - shape of output (1D int tensor)
|
||||||
|
* 1 - min val (0D of output type) - optional (0 as default)
|
||||||
|
* 2 - max val (0D of output type) - optional (inf as default)
|
||||||
|
*
|
||||||
|
* output:
|
||||||
|
* 0 - uniformly distributed values of given type (between min and max)
|
||||||
|
*/
|
||||||
#if NOT_EXCLUDED(OP_randomuniform)
|
#if NOT_EXCLUDED(OP_randomuniform)
|
||||||
DECLARE_CUSTOM_OP(randomuniform, 1, 1, true, 2, 0);
|
DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_random_normal)
|
#if NOT_EXCLUDED(OP_random_normal)
|
||||||
|
@ -66,6 +79,7 @@ namespace nd4j {
|
||||||
#if NOT_EXCLUDED(OP_random_poisson)
|
#if NOT_EXCLUDED(OP_random_poisson)
|
||||||
DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0);
|
DECLARE_CUSTOM_OP(random_poisson, 2, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -23,6 +23,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
//#include <graph/Context.h>
|
//#include <graph/Context.h>
|
||||||
#include <ShapeUtils.h>
|
#include <ShapeUtils.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -127,6 +128,32 @@ namespace helpers {
|
||||||
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context,
|
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context,
|
||||||
graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES);
|
graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||||
|
T minVal = T(0);
|
||||||
|
T maxVal = DataTypeUtils::infOrMax<T>();
|
||||||
|
if (min)
|
||||||
|
minVal = min->t<T>(0);
|
||||||
|
if (max)
|
||||||
|
maxVal = max->t<T>(0);
|
||||||
|
|
||||||
|
if (output->isR())
|
||||||
|
RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
|
||||||
|
else {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
|
for (auto i = 0; i < output->lengthOf(); i++) {
|
||||||
|
output->t<T>(i) = rng.relativeT<T>(i, minVal, maxVal);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
|
||||||
|
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -91,7 +92,7 @@ namespace helpers {
|
||||||
val = nd4j::math::nd4j_min<T>(val, input->t<T>(e));
|
val = nd4j::math::nd4j_min<T>(val, input->t<T>(e));
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
idx = indices->e<int>(e);
|
idx = indices->e<Nd4jLong>(e);
|
||||||
val = input->t<T>(e);
|
val = input->t<T>(e);
|
||||||
}
|
}
|
||||||
output->t<T>(idx) = val;
|
output->t<T>(idx) = val;
|
||||||
|
@ -111,14 +112,14 @@ namespace helpers {
|
||||||
minT->assign(listOfTensors->at(0));
|
minT->assign(listOfTensors->at(0));
|
||||||
|
|
||||||
for (Nd4jLong i = 1; i < indices->lengthOf(); i++) {
|
for (Nd4jLong i = 1; i < indices->lengthOf(); i++) {
|
||||||
if (indices->e<T>(i) == idx) {
|
if (indices->e<Nd4jLong>(i) == idx) {
|
||||||
|
|
||||||
for (int e = 0; e < minT->lengthOf(); e++) {
|
for (int e = 0; e < minT->lengthOf(); e++) {
|
||||||
minT->p(e, nd4j::math::nd4j_min(minT->e<T>(e), listOfTensors->at(i)->e<T>(e)));
|
minT->p(e, nd4j::math::nd4j_min(minT->e<T>(e), listOfTensors->at(i)->e<T>(e)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
idx = indices->e<T>(i);
|
idx = indices->e<Nd4jLong>(i);
|
||||||
minT = listOfOutTensors->at(idx);
|
minT = listOfOutTensors->at(idx);
|
||||||
minT->assign(listOfTensors->at(i));
|
minT->assign(listOfTensors->at(i));
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -26,6 +26,7 @@
|
||||||
#include <helpers/RandomLauncher.h>
|
#include <helpers/RandomLauncher.h>
|
||||||
#include <ShapeUtils.h>
|
#include <ShapeUtils.h>
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
|
#include <cuda_exception.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -181,6 +182,72 @@ namespace helpers {
|
||||||
}
|
}
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_NATIVE);
|
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_NATIVE);
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static __global__ void fillUniformKernel(graph::RandomGenerator* devRng, T from, T to, T* output, Nd4jLong* outputShape) {
|
||||||
|
auto start = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
auto step = blockDim.x * gridDim.x;
|
||||||
|
|
||||||
|
__shared__ Nd4jLong outputLen;
|
||||||
|
|
||||||
|
if (0 == threadIdx.x) {
|
||||||
|
outputLen = shape::length(outputShape);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (auto i = start; i < outputLen; i += step) {
|
||||||
|
auto zIndex = shape::getIndexOffset(i, outputShape);
|
||||||
|
output[zIndex] = devRng->relativeT<T>(i, from, to);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||||
|
T minVal = T(0);
|
||||||
|
T maxVal = DataTypeUtils::infOrMax<T>();
|
||||||
|
if (min)
|
||||||
|
minVal = min->t<T>(0);
|
||||||
|
if (max)
|
||||||
|
maxVal = max->t<T>(0);
|
||||||
|
|
||||||
|
if (output->isR())
|
||||||
|
RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
|
||||||
|
else {
|
||||||
|
auto stream = context->getCudaStream();
|
||||||
|
graph::RandomGenerator *devRng;
|
||||||
|
auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator));
|
||||||
|
if (err != 0) {
|
||||||
|
cuda_exception::build("fillRandomUniform_: Cannot allocate device memory for random generator due error", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cudaMemcpy(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice);
|
||||||
|
if (err != 0) {
|
||||||
|
cuda_exception::build("fillRandomUniform_: Cannot copy random generator to device", err);
|
||||||
|
}
|
||||||
|
auto outputBuf = output->dataBuffer()->specialAsT<T>();
|
||||||
|
auto outputShape = output->specialShapeInfo();
|
||||||
|
fillUniformKernel<T><<<128, 128, 128, *stream>>>(devRng, minVal, maxVal, outputBuf, outputShape);
|
||||||
|
|
||||||
|
err = cudaStreamSynchronize(*stream);
|
||||||
|
if (err != 0) {
|
||||||
|
cuda_exception::build("fillRandomUniform_: Cannot successfully finish kernel call", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cudaFree(devRng);
|
||||||
|
if (err != 0) {
|
||||||
|
cuda_exception::build("fillRandomUniform_: Cannot deallocate device memory for random generator", err);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||||
|
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
|
||||||
|
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -1,5 +1,5 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -33,7 +33,7 @@ namespace helpers {
|
||||||
|
|
||||||
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output);
|
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output);
|
||||||
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output);
|
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output);
|
||||||
|
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -273,8 +274,8 @@ public:
|
||||||
BlockInformation(Nd4jLong length, int threshold) {
|
BlockInformation(Nd4jLong length, int threshold) {
|
||||||
|
|
||||||
threads = length / threshold;
|
threads = length / threshold;
|
||||||
threads = nd4j::math::nd4j_max<int>(1, threads);
|
threads = (1 < threads)?threads:1;//nd4j::math::nd4j_max<int>(1, threads);
|
||||||
threads = nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
|
threads = (threads < omp_get_max_threads())?threads:omp_get_max_threads();//nd4j::math::nd4j_min<int>(threads, omp_get_max_threads());
|
||||||
|
|
||||||
items = length / threads;
|
items = length / threads;
|
||||||
remainder = length % threads;
|
remainder = length % threads;
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -27,7 +28,7 @@
|
||||||
#include <dll.h>
|
#include <dll.h>
|
||||||
#include <pointercast.h>
|
#include <pointercast.h>
|
||||||
#include <platformmath.h>
|
#include <platformmath.h>
|
||||||
|
#include <DataTypeUtils.h>
|
||||||
|
|
||||||
#define BFLOAT16_MAX_VALUE 32737.
|
#define BFLOAT16_MAX_VALUE 32737.
|
||||||
#define HALF_MAX_VALUE 65504.
|
#define HALF_MAX_VALUE 65504.
|
||||||
|
@ -883,7 +884,7 @@ namespace nd4j {
|
||||||
if (a > 171.624) {
|
if (a > 171.624) {
|
||||||
// Correct answer too large to display. Force +infinity.
|
// Correct answer too large to display. Force +infinity.
|
||||||
return Z(DOUBLE_MAX_VALUE);
|
return Z(DOUBLE_MAX_VALUE);
|
||||||
//DataTypeUtils::infOrMax<Z>();
|
// return DataTypeUtils::infOrMax<Z>();
|
||||||
}
|
}
|
||||||
|
|
||||||
return nd4j::math::nd4j_exp<Z,Z>(nd4j::math::nd4j_lgamma<X,Z>(a));
|
return nd4j::math::nd4j_exp<Z,Z>(nd4j::math::nd4j_lgamma<X,Z>(a));
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
*
|
*
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
@ -855,18 +856,39 @@ TEST_F(RNGTests, Test_GammaDistribution_3) {
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(RNGTests, Test_UniformDistribution_04) {
|
||||||
|
auto x = NDArrayFactory::create<Nd4jLong>('c', {1}, {10});
|
||||||
|
auto al = NDArrayFactory::create<int>(1);
|
||||||
|
auto be = NDArrayFactory::create<int>(20);
|
||||||
|
auto exp0 = NDArrayFactory::create<float>('c', {10});
|
||||||
|
|
||||||
|
|
||||||
|
nd4j::ops::randomuniform op;
|
||||||
|
auto result = op.execute({&x, &al, &be}, {}, {DataType::INT32});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
z->printIndexedBuffer("Uniform int distribution");
|
||||||
|
ASSERT_TRUE(exp0.isSameShape(z));
|
||||||
|
ASSERT_FALSE(exp0.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace tests {
|
namespace tests {
|
||||||
static void fillList(Nd4jLong seed, int numberOfArrays, std::vector<Nd4jLong> &shape, std::vector<NDArray*> &list, nd4j::graph::RandomGenerator *rng) {
|
static void fillList(Nd4jLong seed, int numberOfArrays, std::vector<Nd4jLong> &shape, std::vector<NDArray*> &list, nd4j::graph::RandomGenerator *rng) {
|
||||||
rng->setSeed((int) seed);
|
rng->setSeed((int) seed);
|
||||||
|
|
||||||
for (int i = 0; i < numberOfArrays; i++) {
|
for (int i = 0; i < numberOfArrays; i++) {
|
||||||
auto array = NDArrayFactory::create_<double>('c', shape);
|
auto arrayI = NDArrayFactory::create<Nd4jLong>(shape);
|
||||||
|
auto arrayR = NDArrayFactory::create_<double>('c', shape);
|
||||||
|
auto min = NDArrayFactory::create(0.0);
|
||||||
|
auto max = NDArrayFactory::create(1.0);
|
||||||
nd4j::ops::randomuniform op;
|
nd4j::ops::randomuniform op;
|
||||||
op.execute(*rng, {array}, {array}, {0.0, 1.0}, {}, {}, true);
|
op.execute(*rng, {&arrayI, &min, &max}, {arrayR}, {}, {DataType::DOUBLE}, {}, false);
|
||||||
|
|
||||||
list.emplace_back(array);
|
list.emplace_back(arrayR);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue