Oleh multinomial (#163)
* libnd4j: Multinomial op #8570 first raw step of multinomial random data generator implementation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op #8570 next step of multinomial random categories generator implementation on both cpu and cuda, need corrections and code clean up before review and testing * libnd4j: Multinomial op #8570 code clean up and fixed issues data selecting, moved from coords to tads * libnd4j: Multinomial op #8570 fixed cuda build add reference for math materials that was used for implementation * libnd4j: Multinomial op #8570 fixed several bugs, added several tests and improved cuda version. current implementation works, need testing of reproduction with the same seed * libnd4j: Multinomial op #8570 fixes and optimization after discussion in both cuda and cpu * libnd4j: Multinomial op #8570 add corrections after review, removed tads, replace 2D parallel loop by 3D Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op fixed declaration and add tests need discussion * libnd4j: Multinomial op fix in test * libnd4j: Multinomial op corrected behavior to get reproducible results, fixed issue in uniform value getting, tests added, need cuda review and cuda testing Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op fixed indexing on uniform calculation Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op some corrections in max min declaration Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op fixed index calculation, added rewind, corrected input declaration, added stats tests, both cuda and cpu. cuda need testing * libnd4j: Multinomial op fixed bugs on cuda nad cpu. 
need review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op corrected tests to handle different orders Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op some improvements after code review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op more corrections after review Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op fixed seed usage, update tests, fixed cuda based on comments, fixed bug of rewind, removed one behavior, minor corrections. Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op minor corrections Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op rise the bound of fluctuation for random cases Signed-off-by: Oleg <oleg.semeniv@gmail.com> * libnd4j: Multinomial op modified operation inputs and update implementation and tests on both cpu and cuda * libnd4j: Multinomial op corrected data types according ops.proto Co-authored-by: raver119 <raver119@gmail.com>master
parent
bb86bbc255
commit
2404be5fe0
|
@ -0,0 +1,114 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author Oleh Semeniv (oleg.semeniv@gmail.com)
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_random_multinomial)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/RandomLauncher.h>
|
||||||
|
#include <ops/declarable/helpers/random.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
///////////////////////
|
||||||
|
/**
|
||||||
|
* multinomial (categorical) random generator
|
||||||
|
* takes 2D ndarray with logits with shape [batch_size (N), num_classes (K)]
|
||||||
|
* and array with one scalar value of samples number, number of independent samples to draw for each experiment 1,N.
|
||||||
|
* represents the unnormalized log-probabilities for all classes.
|
||||||
|
* Int arguments: 0 - optional argument, corresponds to dimension with batch_size
|
||||||
|
* Int arguments: 1 - optional argument, integer type to use for the output. Default int64.
|
||||||
|
*/
|
||||||
|
// used https://en.wikipedia.org/wiki/Categorical_distribution
|
||||||
|
// methods: gumbel trick + softmax + argmax
|
||||||
|
CUSTOM_OP_IMPL(random_multinomial, 2, 1, false, 0, 0) {
|
||||||
|
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
auto inputSamples = INPUT_VARIABLE(1);
|
||||||
|
|
||||||
|
|
||||||
|
REQUIRE_TRUE(!input->isEmpty(), 0, "RANDOM_MULTINOMIAL OP: Have to be provided at least one logits. ");
|
||||||
|
|
||||||
|
REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample,"
|
||||||
|
" but got no argumets instead.");
|
||||||
|
|
||||||
|
Nd4jLong numOfSamples = static_cast<Nd4jLong>(inputSamples->e<int>(0));
|
||||||
|
// do nothing if number of samples = 0
|
||||||
|
if (0 == numOfSamples)
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater then 0, got %i. ", numOfSamples);
|
||||||
|
|
||||||
|
const int rank = input->rankOf();
|
||||||
|
REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank);
|
||||||
|
|
||||||
|
const int argSize = block.getIArguments()->size();
|
||||||
|
const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;
|
||||||
|
|
||||||
|
auto dimA = (0 == dimC) ? 1 : 0;
|
||||||
|
if (1 == input->sizeAt(dimA)) {
|
||||||
|
*output = 0;
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto rng = block.randomGenerator();
|
||||||
|
helpers::fillRandomMultiNomial(block.launchContext(), rng, *input, *output, numOfSamples, dimC);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Output shape of random_multinomial: the shape of the logits with the class dimension
// replaced by the requested number of samples. The output data type is taken from the
// optional second int argument and defaults to INT64.
DECLARE_SHAPE_FN(random_multinomial) {

    auto input = INPUT_VARIABLE(0);
    auto inputSamples = INPUT_VARIABLE(1);

    REQUIRE_TRUE(inputSamples->lengthOf() == 1, 0, "RANDOM_MULTINOMIAL OP: Have to be specified at least one sample,"
        " but got no arguments instead.");

    Nd4jLong numOfSamples = static_cast<Nd4jLong>(inputSamples->e<int>(0));

    // printf-style varargs: the specifier must match the argument type, hence %lld + long long
    REQUIRE_TRUE(numOfSamples > 0, 0, "RANDOM_MULTINOMIAL OP: Number of samples should be greater then 0, got %lld. ", static_cast<long long>(numOfSamples));

    const int rank = input->rankOf();
    REQUIRE_TRUE(rank == 2, 0, "RANDOM_MULTINOMIAL OP: Logits should be a matrix with rank = 2, but got instead rank = %i.", rank);

    const int argSize = block.getIArguments()->size();
    const int dimC = argSize > 0 ? (INT_ARG(0) >= 0 ? INT_ARG(0) : INT_ARG(0) + rank) : rank - 1;

    // an out-of-range batch dimension would otherwise silently select dimension 0 as class dimension
    REQUIRE_TRUE(dimC >= 0 && dimC < rank, 0, "RANDOM_MULTINOMIAL OP: Batch dimension should be in range [0, %i), but got %i.", rank, dimC);

    auto nShape = input->getShapeAsVector();
    // dimA is the class dimension; it is the one that gets resized to the number of samples
    auto dimA = (0 == dimC) ? 1 : 0;
    nShape[dimA] = numOfSamples;

    // negative or missing type argument falls back to INT64
    DataType nType = (argSize > 1) ? ( INT_ARG(1) >= 0 ? static_cast<DataType>(INT_ARG(1)) : nd4j::DataType::INT64) : nd4j::DataType::INT64;
    return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(nType, input->ordering(), nShape));
}
|
||||||
|
|
||||||
|
// Data type constraints of random_multinomial:
//   input 0 (logits)            - any float or integer type
//   input 1 (number of samples) - INT32 only
//   output 0 (class indices)    - any indexing type
// NOTE(review): the helper dispatch visible in this commit instantiates only FLOAT_TYPES
// for the logits - confirm that integer logits are really supported end to end.
DECLARE_TYPES(random_multinomial) {
    auto descriptor = getOpDescriptor();
    descriptor->setAllowedInputTypes(0, { ALL_FLOATS, ALL_INTS });
    descriptor->setAllowedInputTypes(1, { nd4j::DataType::INT32 });
    descriptor->setAllowedOutputTypes(0, { ALL_INDICES });
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -49,6 +49,22 @@ namespace nd4j {
|
||||||
#if NOT_EXCLUDED(OP_randomuniform)
|
#if NOT_EXCLUDED(OP_randomuniform)
|
||||||
DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0);
|
DECLARE_CUSTOM_OP(randomuniform, 1, 1, false, 0, 0);
|
||||||
#endif
|
#endif
|
||||||
|
/*
 * multinomial (categorical) random generator: draws samples from a multinomial distribution
 *
 * Input arrays:
 *    0 - 2D ndarray with unnormalized log-probabilities with shape [batch_size (N), num_classes (K)]
 *    1 - array with one int value: the number of independent samples to draw for each of the N experiments
 * Int arguments:
 *    0 - optional argument, corresponds to dimension with batch_size
 *    1 - optional argument, integer type to use for the output. Default int64.
 *
 * Output array:
 *    0 - 2D ndarray with the drawn samples of shape [batch_size, num_samples]
 */
|
||||||
|
#if NOT_EXCLUDED(OP_random_multinomial)
|
||||||
|
DECLARE_CUSTOM_OP(random_multinomial, 2, 1, false, 0, 0);
|
||||||
|
#endif
|
||||||
|
|
||||||
#if NOT_EXCLUDED(OP_random_normal)
|
#if NOT_EXCLUDED(OP_random_normal)
|
||||||
DECLARE_CUSTOM_OP(random_normal, 1, 1, true, 2, 0);
|
DECLARE_CUSTOM_OP(random_normal, 1, 1, true, 2, 0);
|
||||||
|
|
|
@ -24,6 +24,8 @@
|
||||||
//#include <graph/Context.h>
|
//#include <graph/Context.h>
|
||||||
#include <ShapeUtils.h>
|
#include <ShapeUtils.h>
|
||||||
#include <helpers/RandomLauncher.h>
|
#include <helpers/RandomLauncher.h>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -150,6 +152,61 @@ namespace helpers {
|
||||||
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
|
||||||
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// used https://en.wikipedia.org/wiki/Categorical_distribution
|
||||||
|
// methods: gumbel trick + softmax + argmax
|
||||||
|
template <typename Tx, typename Tz>
|
||||||
|
void fillRandomMultiNomial_(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
|
||||||
|
|
||||||
|
const Tx* x = input.bufferAsT<Tx>();
|
||||||
|
Tz* z = output.bufferAsT<Tz>();
|
||||||
|
|
||||||
|
Tx minVal = DataTypeUtils::min<Tx>();
|
||||||
|
Tx maxVal = 1.0;
|
||||||
|
|
||||||
|
auto dimA = (0 == dimC) ? 1 : 0;
|
||||||
|
const Nd4jLong batchValue = output.sizeAt(dimC);
|
||||||
|
const Nd4jLong numOfClassX = input.sizeAt(dimA);
|
||||||
|
|
||||||
|
const Nd4jLong zDimAstride = output.stridesOf()[dimA];
|
||||||
|
const Nd4jLong xDimAstride = input.stridesOf()[dimA];
|
||||||
|
const Nd4jLong zDimCstride = output.stridesOf()[dimC];
|
||||||
|
const Nd4jLong xDimCstride = input.stridesOf()[dimC];
|
||||||
|
|
||||||
|
auto func = PRAGMA_THREADS_FOR_2D{
|
||||||
|
for (auto nBatchIndex = start_x; nBatchIndex < stop_x; nBatchIndex += inc_x) {
|
||||||
|
for (auto nSampleIndexInBatch = start_y; nSampleIndexInBatch < stop_y; nSampleIndexInBatch += inc_y) {
|
||||||
|
|
||||||
|
const Tx* xTad = x + (nBatchIndex * xDimCstride);
|
||||||
|
Tz* zTad = z + (nBatchIndex * zDimCstride);
|
||||||
|
Tz& arg = zTad[nSampleIndexInBatch * zDimAstride];
|
||||||
|
Tx Max = -minVal;
|
||||||
|
|
||||||
|
auto nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples;
|
||||||
|
auto nClassesPerSample = nSampleIndexInBatch * numOfClassX;
|
||||||
|
for (auto nClass = 0; nClass < numOfClassX; nClass += 1) {
|
||||||
|
auto nIndex = nSamplesPerBatch + nClassesPerSample + nClass;
|
||||||
|
auto unifornLog = nd4j::math::nd4j_log<Tx, Tx>(-nd4j::math::nd4j_log<Tx, Tx>(rng.relativeT<Tx>(nIndex, minVal, maxVal)));
|
||||||
|
Tx tValue = (xTad[nClass * xDimAstride] - unifornLog);
|
||||||
|
if (tValue > Max) {
|
||||||
|
Max = tValue;
|
||||||
|
arg = nClass;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1);
|
||||||
|
rng.rewindH(output.lengthOf()*numOfClassX);
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// CPU entry point: selects the fillRandomMultiNomial_ instantiation that matches the
// data types of input (float types) and output (indexing types) and forwards all arguments.
void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
    BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(),
                          fillRandomMultiNomial_,
                          (context, rng, input, output, numOfSamples, dimC),
                          FLOAT_TYPES, INDEXING_TYPES);
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -27,6 +27,8 @@
|
||||||
#include <ShapeUtils.h>
|
#include <ShapeUtils.h>
|
||||||
#include <NDArrayFactory.h>
|
#include <NDArrayFactory.h>
|
||||||
#include <cuda_exception.h>
|
#include <cuda_exception.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
#include <PointersManager.h>
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
@ -248,6 +250,116 @@ namespace helpers {
|
||||||
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
|
BUILD_SINGLE_TEMPLATE(template void fillRandomUniform_, (LaunchContext* context,
|
||||||
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
|
graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output), NUMERIC_TYPES);
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
// used https://en.wikipedia.org/wiki/Categorical_distribution
|
||||||
|
// methods: gumbel trick + softmax + argmax
|
||||||
|
template<typename X, typename Z>
|
||||||
|
__global__ static void fillMultiNomialCuda_(graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo, const Nd4jLong batchValue,
|
||||||
|
const Nd4jLong numOfSamples, const Nd4jLong numOfClassX,
|
||||||
|
const Nd4jLong dimA, const X minVal, const X maxVal) {
|
||||||
|
|
||||||
|
|
||||||
|
const X* x = reinterpret_cast<const X*>(vx);
|
||||||
|
Z* z = reinterpret_cast<Z*>(vz);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong xDimAstride, zDimAstride, xDimCstride, zDimCstride, dimC;
|
||||||
|
|
||||||
|
if (0 == threadIdx.x) {
|
||||||
|
dimC = (0 == dimA) ? 1 : 0;
|
||||||
|
zDimAstride = shape::stride(zShapeInfo)[dimA];
|
||||||
|
xDimAstride = shape::stride(xShapeInfo)[dimA];
|
||||||
|
zDimCstride = shape::stride(zShapeInfo)[dimC];
|
||||||
|
xDimCstride = shape::stride(xShapeInfo)[dimC];
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
for (Nd4jLong index = tid; index < batchValue*numOfSamples; index += gridDim.x * blockDim.x) {
|
||||||
|
|
||||||
|
Nd4jLong nBatchIndex = index / numOfSamples;
|
||||||
|
Nd4jLong nSampleIndexInBatch = index - (nBatchIndex * numOfSamples);
|
||||||
|
|
||||||
|
const X* xTad = x + (nBatchIndex * xDimCstride);
|
||||||
|
Z* zTad = z + (nBatchIndex * zDimCstride);
|
||||||
|
Z& arg = zTad[nSampleIndexInBatch * zDimAstride];
|
||||||
|
|
||||||
|
X Max = -minVal;
|
||||||
|
Nd4jLong nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples;
|
||||||
|
Nd4jLong nClassPerSamples = nSampleIndexInBatch * numOfClassX;
|
||||||
|
|
||||||
|
for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass++) {
|
||||||
|
Nd4jLong nIndex = nSamplesPerBatch + nClassPerSamples + nClass;
|
||||||
|
X tValue = (xTad[nClass * xDimAstride] - nd4j::math::nd4j_log<X, X>(-nd4j::math::nd4j_log<X, X>(devRng->relativeT<X>(nIndex, minVal, maxVal))));
|
||||||
|
if (tValue > Max) {
|
||||||
|
Max = tValue;
|
||||||
|
arg = nClass;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Z>
|
||||||
|
__host__ static void fillMultiNomialCudaLauncher(
|
||||||
|
const int blocksPerGrid, const int threadsPerBlock, const cudaStream_t* stream,
|
||||||
|
graph::RandomGenerator* devRng, const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo,
|
||||||
|
const Nd4jLong batchValue, const Nd4jLong numOfSamples,
|
||||||
|
const Nd4jLong numOfClassX, const Nd4jLong dimA){
|
||||||
|
|
||||||
|
const X minVal = DataTypeUtils::min<X>();
|
||||||
|
const X maxVal = 1.0;
|
||||||
|
|
||||||
|
fillMultiNomialCuda_<X, Z> <<< blocksPerGrid, threadsPerBlock, 256, * stream >>> (
|
||||||
|
devRng, vx, xShapeInfo, vz, zShapeInfo, batchValue,
|
||||||
|
numOfSamples, numOfClassX, dimA, minVal, maxVal);
|
||||||
|
}
|
||||||
|
|
||||||
|
///////////////////////////////////////////////////////////////////
|
||||||
|
void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
|
||||||
|
|
||||||
|
Nd4jLong dimA = (0 == dimC) ? 1 : 0;
|
||||||
|
|
||||||
|
const Nd4jLong batchValue = output.sizeAt(dimC);
|
||||||
|
const Nd4jLong numOfClassX = input.sizeAt(dimA);
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (batchValue * numOfSamples + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
|
||||||
|
PointersManager manager(context, "fillMultinomial");
|
||||||
|
graph::RandomGenerator *devRng;
|
||||||
|
|
||||||
|
auto err = cudaMalloc(&devRng, sizeof(graph::RandomGenerator));
|
||||||
|
if (err != 0) {
|
||||||
|
cuda_exception::build("fillRandomMultiNomial: Cannot allocate device memory for random generator due error", err);
|
||||||
|
}
|
||||||
|
err = cudaStreamSynchronize(*context->getCudaStream());
|
||||||
|
if (err != 0) {
|
||||||
|
cuda_exception::build("fillRandomMultiNomial: Cannot synchronize stream for random generator due error", err);
|
||||||
|
}
|
||||||
|
err = cudaMemcpyAsync(devRng, &rng, sizeof(graph::RandomGenerator), cudaMemcpyHostToDevice, *context->getCudaStream());
|
||||||
|
if (err != 0) {
|
||||||
|
cuda_exception::build("fillRandomMultiNomial: Cannot copy random generator to device", err);
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray::prepareSpecialUse({ &output }, { &input });
|
||||||
|
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillMultiNomialCudaLauncher,
|
||||||
|
(blocksPerGrid, threadsPerBlock, context->getCudaStream(), devRng, input.getSpecialBuffer(),
|
||||||
|
input.getSpecialShapeInfo(), output.specialBuffer(),
|
||||||
|
output.specialShapeInfo(), batchValue, numOfSamples,
|
||||||
|
numOfClassX, dimA), FLOAT_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({ &output }, { &input });
|
||||||
|
manager.synchronize();
|
||||||
|
|
||||||
|
err = cudaFree(devRng);
|
||||||
|
if (err != 0) {
|
||||||
|
cuda_exception::build("fillRandomMultiNomial: Cannot deallocate device memory for random generator", err);
|
||||||
|
}
|
||||||
|
rng.rewindH(output.lengthOf() * numOfClassX);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -34,6 +34,7 @@ namespace helpers {
|
||||||
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output);
|
void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output);
|
||||||
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output);
|
void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output);
|
||||||
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output);
|
void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output);
|
||||||
|
// Fills output with class indices sampled from the categorical distributions given by the
// logits in input; numOfSamples samples are drawn per batch entry and dimC is the dimension
// that enumerates the batch entries. rng supplies the random values and is rewound by the call.
void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1003,3 +1003,204 @@ TEST_F(RNGTests, test_uniform_119) {
|
||||||
auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {});
|
auto status = op.execute({&x}, {&z}, {1.0, 2.0}, {}, {});
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Seeded run on a 3x3 'f'-ordered logits matrix must reproduce the stored indices;
// a logits matrix with a single class must yield all zeros.
TEST_F(RNGTests, test_multinomial_1) {

    NDArray probs('f', { 3, 3 }, { 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
    NDArray expected('f', { 3, 3 }, { 0, 1, 2, 2, 0, 0, 1, 2, 1 }, nd4j::DataType::INT64);
    NDArray output('f', { 3, 3 }, nd4j::DataType::INT64);
    NDArray samples('f', { 1 }, { 3 }, nd4j::DataType::INT32);

    nd4j::ops::random_multinomial op;
    RandomGenerator rng(1234, 1234);
    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64}, {}, false) );
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    // batch dimension 1 -> dimension 0 (size 1) holds the classes, so only class 0 can be drawn
    NDArray probsZ('c', { 1, 3 }, { 0.3, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
    NDArray expectedZ('c', { 3, 3 }, { 0, 0, 0, 0, 0, 0, 0, 0, 0 }, nd4j::DataType::INT64);

    auto result = op.execute({ &probsZ, &samples }, { }, { 1, INT64 });
    // check the status first: the result set is only meaningful after a successful run
    ASSERT_EQ(Status::OK(), result->status());
    auto outputZ = result->at(0);

    ASSERT_TRUE(expectedZ.isSameShape(outputZ));
    ASSERT_TRUE(expectedZ.equalsTo(outputZ));
    delete result;
}
|
||||||
|
|
||||||
|
// Seeded runs with 20 samples per batch entry, once with the batch on dimension 0
// and once with the batch on dimension 1, compared against stored indices.
TEST_F(RNGTests, test_multinomial_2) {

    NDArray samples('c', { 1 }, { 20 }, nd4j::DataType::INT32);
    nd4j::ops::random_multinomial op;
    RandomGenerator rng(1234, 1234);

    // batch on dimension 0: 3 batch entries, 5 classes
    NDArray probs('c', { 3, 5 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32);
    NDArray expected('c', { 3, 20 }, { 0, 2, 0, 2, 0, 4, 2, 0, 1, 2, 0, 2, 3, 0, 0, 2, 4, 4, 1, 0, 2, 3, 2, 3, 0, 1, 3, 1, 1, 1, 2, 4, 3, 3, 1, 4, 4, 2, 0, 0, 3, 3, 3, 0, 0, 2, 2, 3, 3, 0, 0, 2, 3, 4, 2, 2, 3, 2, 1, 2 }, nd4j::DataType::INT64);
    NDArray output('c', { 3, 20 }, nd4j::DataType::INT64);

    const auto statusRows = op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false);
    ASSERT_EQ(Status::OK(), statusRows);
    ASSERT_TRUE(expected.isSameShape(output));
    ASSERT_TRUE(expected.equalsTo(output));

    // batch on dimension 1: 5 classes, 3 batch entries
    NDArray probs2('c', { 5, 3 }, { 0.2, 0.3, 0.5, 0.3, 0.5, 0.2, 0.5, 0.2, 0.3, 0.35, 0.25, 0.3, 0.25, 0.25, 0.5 }, nd4j::DataType::FLOAT32);
    NDArray expected2('c', { 20, 3 }, { 0, 2, 3, 2, 3, 3, 0, 2, 3, 2, 3, 0, 0, 0, 0, 4, 1, 2, 2, 3, 2, 3, 1, 3, 1, 1, 3, 2, 1, 0, 0, 2, 0, 2, 4, 2, 3, 3, 3, 0, 3, 4, 0, 1, 2, 2, 0, 2, 4, 4, 0, 4, 2, 2, 1, 0, 1, 0, 0, 2 }, nd4j::DataType::INT64);
    NDArray output2('c', { 20, 3 }, nd4j::DataType::INT64);

    // restart the generator from the same state before the second run
    rng.setStates(1234, 1234);
    const auto statusCols = op.execute(rng, { &probs2, &samples }, { &output2 }, {}, { 1, INT64 }, {}, false);
    ASSERT_EQ(Status::OK(), statusCols);
    ASSERT_TRUE(expected2.isSameShape(output2));
    ASSERT_TRUE(expected2.equalsTo(output2));
}
|
||||||
|
|
||||||
|
// Reproducibility with the batch on dimension 0: two runs started from the same
// generator state must produce identical samples.
TEST_F(RNGTests, test_multinomial_3) {

    NDArray logits('c', { 4, 3 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
    NDArray firstRun('c', { 4, 5 }, nd4j::DataType::INT64);
    NDArray secondRun('c', { 4, 5 }, nd4j::DataType::INT64);
    NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32);
    RandomGenerator rng(1234, 1234);

    nd4j::ops::random_multinomial op;

    // reference run
    const auto firstStatus = op.execute(rng, { &logits, &samples }, { &firstRun }, {}, { 0, INT64 }, {}, false);
    ASSERT_EQ(Status::OK(), firstStatus);

    // same state again -> same samples expected
    rng.setStates(1234, 1234);
    const auto secondStatus = op.execute(rng, { &logits, &samples }, { &secondRun }, {}, { 0, INT64 }, {}, false);
    ASSERT_EQ(Status::OK(), secondStatus);

    ASSERT_TRUE(firstRun.isSameShape(secondRun));
    ASSERT_TRUE(firstRun.equalsTo(secondRun));
}
|
||||||
|
|
||||||
|
// Reproducibility with the batch on dimension 1: two runs started from the same
// generator state must produce identical samples.
TEST_F(RNGTests, test_multinomial_4) {

    NDArray logits('c', { 3, 4 }, { 0.3, 0.3, 0.4, 0.3, 0.4, 0.3, 0.3, 0.3, 0.4, 0.4, 0.3, 0.3 }, nd4j::DataType::FLOAT32);
    NDArray firstRun('c', { 5, 4 }, nd4j::DataType::INT64);
    NDArray secondRun('c', { 5, 4 }, nd4j::DataType::INT64);
    NDArray samples('c', { 1 }, { 5 }, nd4j::DataType::INT32);

    RandomGenerator rng(1234, 1234);
    nd4j::ops::random_multinomial op;

    // reference run
    const auto firstStatus = op.execute(rng, { &logits, &samples }, { &firstRun }, {}, { 1, INT64 }, {}, false);
    ASSERT_EQ(Status::OK(), firstStatus);

    // same state again -> same samples expected
    rng.setStates(1234, 1234);
    const auto secondStatus = op.execute(rng, { &logits, &samples }, { &secondRun }, {}, { 1, INT64 }, {}, false);
    ASSERT_EQ(Status::OK(), secondStatus);

    ASSERT_TRUE(firstRun.isSameShape(secondRun));
    ASSERT_TRUE(firstRun.equalsTo(secondRun));
}
|
||||||
|
|
||||||
|
// Statistical check: with two equally weighted classes the multinomial degenerates to a
// binomial, so mean and standard deviation of the drawn indices should both be close to 0.5.
TEST_F(RNGTests, test_multinomial_5) {
    // multinomial as binomial if 2 classes used
    int batchValue = 1;
    int ClassValue = 2;
    int Samples = 1000000;

    NDArray samples('c', { 1 }, { 1.*Samples }, nd4j::DataType::INT32);

    NDArray probs('c', { ClassValue, batchValue }, { 1.0, 1.0 }, nd4j::DataType::FLOAT32);

    nd4j::ops::random_multinomial op;

    // run with a fixed generator state
    NDArray output('c', { Samples, batchValue }, nd4j::DataType::INT64);
    RandomGenerator rng(1234, 1234);

    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 1 }, {}, false));

    auto deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
    auto mean = output.meanNumber();
    // printf("Var: %f  Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
    // theoretical values for binomial
    ASSERT_NEAR(0.5, deviation.e<double>(0), 3e-3);
    ASSERT_NEAR(0.5, mean.e<double>(0), 3e-3);

    for (int i = 0; i < output.lengthOf(); i++) {
        auto value = output.e<Nd4jLong>(i);
        ASSERT_TRUE(value >= 0 && value < ClassValue);
    }

    // run without an explicit generator; wider tolerance is used for this case
    auto resultR = op.execute({ &probs, &samples }, { }, { 1 });
    // check the status first: the result set is only meaningful after a successful run
    ASSERT_EQ(Status::OK(), resultR->status());
    auto outputR = resultR->at(0);

    deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
    mean = outputR->meanNumber();
    // printf("Random seed - Var: %f  Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
    ASSERT_NEAR(0.5, deviation.e<double>(0), 35e-3);
    ASSERT_NEAR(0.5, mean.e<double>(0), 35e-3);

    for (int i = 0; i < outputR->lengthOf(); i++) {
        auto value = outputR->e<Nd4jLong>(i);
        ASSERT_TRUE(value >= 0 && value < ClassValue);
    }

    delete resultR;
}
|
||||||
|
|
||||||
|
|
||||||
|
// Statistical check with five classes: the observed class frequencies, the mean and the
// standard deviation of the drawn indices are compared with the values expected for the
// logits {1, 1.5, 2, 2.5, 3}, first without and then with an explicit generator state.
TEST_F(RNGTests, test_multinomial_6) {

    int batchValue = 1;
    int ClassValue = 5;
    int Samples = 1000000;

    NDArray samples('c', { 1 }, { 1. * Samples }, nd4j::DataType::INT32);

    nd4j::ops::random_multinomial op;
    // expected class probabilities for the logits used below
    NDArray probExpect('c', { ClassValue }, { 0.058, 0.096, 0.1576, 0.2598, 0.4287 }, nd4j::DataType::DOUBLE);

    // without seed
    NDArray probsR('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);

    auto resultR = op.execute({ &probsR, &samples }, { }, { 0 });
    // check the status first: the result set is only meaningful after a successful run
    ASSERT_EQ(Status::OK(), resultR->status());
    auto outputR = resultR->at(0);

    NDArray countsR('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);

    // histogram of the drawn class indices
    for (int i = 0; i < outputR->lengthOf(); i++) {
        auto value = outputR->e<Nd4jLong>(i);
        ASSERT_TRUE(value >= 0 && value < ClassValue);
        double* z = countsR.bufferAsT<double>();
        z[value] += 1;
    }

    for (int i = 0; i < countsR.lengthOf(); i++) {
        auto c = countsR.e<double>(i);
        auto p = probExpect.e<double>(i);
        // printf("Get freq : %f  Expect freq: %f \n", c / Samples, p);
        ASSERT_NEAR((c / Samples), p, 35e-3);
    }

    auto deviation = outputR->varianceNumber(variance::SummaryStatsStandardDeviation, false);
    auto mean = outputR->meanNumber();
    // printf("Var: %f  Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
    ASSERT_NEAR(1.2175, deviation.e<double>(0), 35e-3);
    ASSERT_NEAR(2.906, mean.e<double>(0), 35e-3);

    delete resultR;

    // with a fixed generator state; tighter tolerances are used for this case
    RandomGenerator rng(1234, 1234);
    NDArray probs('c', { batchValue, ClassValue }, { 1., 1.5, 2., 2.5, 3. }, nd4j::DataType::FLOAT32);
    NDArray output('c', { batchValue, Samples }, nd4j::DataType::INT64);

    ASSERT_EQ(Status::OK(), op.execute(rng, { &probs, &samples }, { &output }, {}, { 0, INT64 }, {}, false));

    NDArray counts('c', { ClassValue }, { 0, 0, 0, 0, 0 }, nd4j::DataType::DOUBLE);

    // histogram of the drawn class indices
    for (int i = 0; i < output.lengthOf(); i++) {
        auto value = output.e<Nd4jLong>(i);
        ASSERT_TRUE(value >= 0 && value < ClassValue);
        double* z = counts.bufferAsT<double>();
        z[value] += 1;
    }

    for (int i = 0; i < counts.lengthOf(); i++) {
        auto c = counts.e<double>(i);
        auto p = probExpect.e<double>(i);
        // printf("Get freq : %f  Expect freq: %f \n", c / Samples, p);
        ASSERT_NEAR((c / Samples), p, 3e-3);
    }

    deviation = output.varianceNumber(variance::SummaryStatsStandardDeviation, false);
    mean = output.meanNumber();
    // printf("Var: %f  Mean: %f \n", deviation.e<double>(0), mean.e<double>(0));
    ASSERT_NEAR(1.2175, deviation.e<double>(0), 3e-3);
    ASSERT_NEAR(2.906, mean.e<double>(0), 3e-3);
}
|
||||||
|
|
Loading…
Reference in New Issue