/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author sgazeos@gmail.com
//

#include <ops/declarable/helpers/random.h>
#include <memory>
#include <helpers/ShapeUtils.h>
#include <helpers/RandomLauncher.h>
#include <execution/Threads.h>
#include <helpers/ConstantTadHelper.h>

namespace sd {
namespace ops {
namespace helpers {
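// Fills `output` with values parameterized by `alpha` and, optionally, `beta`.
// The two parameter arrays are first broadcast to a common shape; for every
// slice of `output` a single uniform value u in [0, 1) is drawn, and each
// element is set to the regularized incomplete gamma function
// igamma(alpha_e, beta_e * u) (or igamma(alpha_e, u) when beta is absent).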
template <typename T>
void fillRandomGamma_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
    auto broadcasted = alpha->shapeInfo();
    if (beta != nullptr) {
        const Nd4jLong* broadcastedShape = nullptr;
        ShapeUtils::evalBroadcastShapeInfo(*alpha, *beta, true, broadcastedShape, context->getWorkspace());
        broadcasted = broadcastedShape;
    }

    auto step = shape::length(broadcasted);
    auto shift = output->lengthOf() / step;

    auto copyAlpha = alpha;
    auto copyBeta = beta;
    if (beta != nullptr) {
        NDArray alphaBroadcasted(broadcasted, alpha->dataType(), false, context);
        NDArray betaBroadcasted(broadcasted, beta->dataType(), false, context);

        copyAlpha = new NDArray(alphaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *alpha));
        copyBeta = new NDArray(betaBroadcasted.applyTrueBroadcast(BroadcastOpsTuple::Assign(), *beta));
    }

    bool directOutput = output->ews() == 1 && output->ordering() == 'c';
    T* outputBuf = output->dataBuffer()->primaryAsT<T>();

    PRAGMA_OMP_PARALLEL_FOR
    for (Nd4jLong k = 0; k < shift; k++) {
        auto pos = k * step;
        auto u = rng.relativeT<T>(k, 0., 1.);
        for (Nd4jLong e = 0; e < step; e++) {
            if (directOutput) {
                outputBuf[pos + e] = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
                                                                beta != nullptr ? copyBeta->t<T>(e) * u : u);
            }
            else {
                output->t<T>(pos + e) = math::nd4j_igamma<T, T, T>(copyAlpha->t<T>(e),
                                                                   beta != nullptr ? copyBeta->t<T>(e) * u : u);
            }
        }
    }

    if (beta != nullptr) {
        delete copyAlpha;
        delete copyBeta;
    }
}

void fillRandomGamma(LaunchContext* context, graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output) {
    BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomGamma_, (context, rng, alpha, beta, output), FLOAT_NATIVE);
}
BUILD_SINGLE_TEMPLATE(template void fillRandomGamma_, (LaunchContext* context,
                      graph::RandomGenerator& rng, NDArray* alpha, NDArray* beta, NDArray* output), FLOAT_NATIVE);
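// Example (hypothetical caller; assumes a valid LaunchContext* ctx and a seeded
// graph::RandomGenerator rng):
//
//   auto alpha  = NDArrayFactory::create<float>('c', {2}, {2.f, 3.f});
//   auto output = NDArrayFactory::create<float>('c', {1000, 2});
//   fillRandomGamma(ctx, rng, &alpha, nullptr, &output);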

/*
 * Poisson generator based upon the inversion by sequential search [48]:505
 *   init:
 *     Let x ← 0, p ← e^(-λ), s ← p.
 *     Generate uniform random number u in [0, 1].
 *   while u > s do:
 *     x ← x + 1.
 *     p ← p * λ / x.
 *     s ← s + p.
 *   return x.
 */
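// Worked example of the loop above, assuming λ = 1 and u = 0.5: initially
// p = s = e^(-1) ≈ 0.368, so u > s; the first iteration sets x = 1,
// p ≈ 0.368, s ≈ 0.736, at which point u <= s and 1 is returned.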
template <typename T>
void fillRandomPoisson_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
    auto shift = output->lengthOf() / lambda->lengthOf();
    auto step = lambda->lengthOf();
    T* lambdaBuf = lambda->dataBuffer()->primaryAsT<T>();
    T* outputBuf = output->dataBuffer()->primaryAsT<T>();
    bool directLa = lambda->ews() == 1 && lambda->ordering() == 'c';
    bool directOut = output->ews() == 1 && output->ordering() == 'c';

    PRAGMA_OMP_PARALLEL_FOR
    for (Nd4jLong k = 0; k < shift; k++) {
        auto pos = k * step;
        auto u = rng.relativeT<T>(k, 0., 1.);
        for (Nd4jLong e = 0; e < step; e++) {
            auto p = math::nd4j_exp<T, T>(-lambda->t<T>(e));
            auto s = p;
            auto x = T(0.f);
            while (u > s) {
                x += 1.f;
                p *= directLa ? lambdaBuf[e] / x : lambda->t<T>(e) / x;
                s += p;
            }
            if (directOut)
                outputBuf[pos + e] = x;
            else
                output->t<T>(pos + e) = x;
        }
    }
}

void fillRandomPoisson(LaunchContext* context, graph::RandomGenerator& rng, NDArray* lambda, NDArray* output) {
    BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomPoisson_, (context, rng, lambda, output), FLOAT_NATIVE);
}
BUILD_SINGLE_TEMPLATE(template void fillRandomPoisson_, (LaunchContext* context,
                      graph::RandomGenerator& rng, NDArray* lambda, NDArray* output), FLOAT_TYPES);
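// Example (hypothetical caller; assumes a valid LaunchContext* ctx and a seeded
// graph::RandomGenerator rng; the output length must be a multiple of the
// lambda length, since one uniform draw is shared per lambda-sized slice):
//
//   auto lambda = NDArrayFactory::create<float>('c', {2}, {1.f, 4.f});
//   auto output = NDArrayFactory::create<float>('c', {1000, 2});
//   fillRandomPoisson(ctx, rng, &lambda, &output);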

template <typename T>
void fillRandomUniform_(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
    T minVal = T(0);
    T maxVal = DataTypeUtils::max<T>();
    if (min)
        minVal = min->t<T>(0);
    if (max)
        maxVal = max->t<T>(0);

    if (output->isR())
        RandomLauncher::fillUniform(context, rng, output, minVal, maxVal);
    else {
        PRAGMA_OMP_PARALLEL_FOR
        for (Nd4jLong i = 0; i < output->lengthOf(); i++) {
            output->t<T>(i) = rng.relativeT<T>(i, minVal, maxVal);
        }
    }
}

void fillRandomUniform(LaunchContext* context, graph::RandomGenerator& rng, NDArray* min, NDArray* max, NDArray* output) {
    BUILD_SINGLE_SELECTOR(output->dataType(), fillRandomUniform_, (context, rng, min, max, output), NUMERIC_TYPES);
}
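// Note: real-valued outputs are delegated wholesale to
// RandomLauncher::fillUniform, while integral outputs are filled element-wise
// through rng.relativeT. When min/max are not supplied, the range defaults to
// [0, DataTypeUtils::max<T>()].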

// categorical distribution, see https://en.wikipedia.org/wiki/Categorical_distribution
// method: Gumbel trick + softmax + argmax
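// The Gumbel-max trick reduces categorical sampling to an argmax: with logits
// x_i and independent u_i ~ U(0, 1), the index
//   argmax_i ( x_i - log(-log(u_i)) )
// is distributed according to softmax(x). The loop below tracks this argmax
// directly, one class at a time.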
template <typename Tx, typename Tz>
void fillRandomMultiNomial_(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
    const Tx* x = input.bufferAsT<Tx>();
    Tz* z = output.bufferAsT<Tz>();

    Tx minVal = DataTypeUtils::min<Tx>();
    Tx maxVal = 1.0;

    auto dimA = (0 == dimC) ? 1 : 0;
    const Nd4jLong batchValue = output.sizeAt(dimC);
    const Nd4jLong numOfClassX = input.sizeAt(dimA);

    const Nd4jLong zDimAstride = output.stridesOf()[dimA];
    const Nd4jLong xDimAstride = input.stridesOf()[dimA];
    const Nd4jLong zDimCstride = output.stridesOf()[dimC];
    const Nd4jLong xDimCstride = input.stridesOf()[dimC];

    auto func = PRAGMA_THREADS_FOR_2D {
        for (auto nBatchIndex = start_x; nBatchIndex < stop_x; nBatchIndex += inc_x) {
            for (auto nSampleIndexInBatch = start_y; nSampleIndexInBatch < stop_y; nSampleIndexInBatch += inc_y) {
                const Tx* xTad = x + (nBatchIndex * xDimCstride);
                Tz* zTad = z + (nBatchIndex * zDimCstride);
                Tz& arg = zTad[nSampleIndexInBatch * zDimAstride];
                Tx Max = -minVal;

                auto nSamplesPerBatch = nBatchIndex * numOfClassX * numOfSamples;
                auto nClassesPerSample = nSampleIndexInBatch * numOfClassX;
                for (Nd4jLong nClass = 0; nClass < numOfClassX; nClass += 1) {
                    auto nIndex = nSamplesPerBatch + nClassesPerSample + nClass;
                    auto uniformLog = sd::math::nd4j_log<Tx, Tx>(-sd::math::nd4j_log<Tx, Tx>(rng.relativeT<Tx>(nIndex, minVal, maxVal)));
                    Tx tValue = (xTad[nClass * xDimAstride] - uniformLog);
                    if (tValue > Max) {
                        Max = tValue;
                        arg = nClass;
                    }
                }
            }
        }
    };

    samediff::Threads::parallel_for(func, 0, batchValue, 1, 0, numOfSamples, 1);
    rng.rewindH(output.lengthOf() * numOfClassX);
}

void fillRandomMultiNomial(LaunchContext* context, graph::RandomGenerator& rng, NDArray& input, NDArray& output, const Nd4jLong numOfSamples, const int dimC) {
    BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), fillRandomMultiNomial_, (context, rng, input, output, numOfSamples, dimC), FLOAT_TYPES, INDEXING_TYPES);
}
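// Example (hypothetical caller; assumes a valid LaunchContext* ctx and a seeded
// graph::RandomGenerator rng; with dimC == 0 the batch runs along dimension 0
// and classes along dimension 1 of `input`):
//
//   auto logits = NDArrayFactory::create<float>('c', {3, 5});       // 3 batches, 5 classes
//   auto output = NDArrayFactory::create<Nd4jLong>('c', {3, 10});   // 10 samples per batch
//   fillRandomMultiNomial(ctx, rng, logits, output, 10, 0);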

}
}
}