R119 random shuffle (#488)
* random_shuffle test for Yurii

  Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* - implementation and testing random_shuffle for vector case (cpu)
* - fix bug in random shuffle for cpu
* - correct tests for random shuffle and improve alg when inPlace is false
* - implementation of random shuffle algorithm for cuda
* - split cuda random shuffle alg into separate launches of 2 kernels
* - minor corrections in cuda concat kernel

  Signed-off-by: Yurii <iuriish@yahoo.com>

Co-authored-by: raver119@gmail.com <raver119@gmail.com>
parent 8733c0c3ed
commit bb0492f47d
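For orientation before the diff: the new helpers implement the MergeShuffle scheme from the article cited below. The idea is to split the buffer into 2^power chunks, run an independent Fisher-Yates shuffle inside each chunk, and then merge-shuffle adjacent pairs of chunks round by round until a single shuffled range remains. The following standalone C++ sketch illustrates that scheme only; it is not the library code, and it assumes a contiguous buffer (ews == 1) with std::mt19937_64 standing in for sd::graph::RandomGenerator.

    #include <algorithm>
    #include <cstdint>
    #include <iostream>
    #include <random>
    #include <vector>

    static std::mt19937_64 gen(123);   // stands in for sd::graph::RandomGenerator

    // Fisher-Yates shuffle of buff[0..len)
    template <typename T>
    void fisherYates(T* buff, int64_t len) {
        for (int64_t i = len - 1; i > 0; --i) {
            const int64_t j = (int64_t)(gen() % (uint64_t)(i + 1));
            std::swap(buff[i], buff[j]);
        }
    }

    // mutual shuffle of two adjacent, already shuffled runs [0,len1) and [len1,totLen)
    template <typename T>
    void mergeShuffle(T* buff, int64_t len1, int64_t totLen) {
        int64_t beg = 0, mid = len1;
        while (true) {
            if (gen() % 2) {              // coin flip: consume from the right run
                if (mid == totLen) break;
                std::swap(buff[beg], buff[mid++]);
            } else {                      // consume from the left run
                if (beg == mid) break;
            }
            ++beg;
        }
        // one run is exhausted: place the tail by Fisher-Yates insertions
        for (; beg < totLen; ++beg) {
            const int64_t j = (int64_t)(gen() % (uint64_t)(beg + 1));
            std::swap(buff[beg], buff[j]);
        }
    }

    int main() {
        const int64_t len = 1000;
        const int power = 3;                       // 2^3 = 8 chunks
        const int64_t numChunks = int64_t(1) << power;

        std::vector<int> v(len);
        for (int64_t i = 0; i < len; ++i) v[i] = (int)i;

        // 1) independent Fisher-Yates per chunk; bounds use the same
        //    fixed-point arithmetic as the commit: offset_i = (len*i) >> power
        for (int64_t i = 0; i < numChunks; ++i) {
            const int64_t offset = (len * i) >> power;
            const int64_t curLen = ((len * (i + 1)) >> power) - offset;
            fisherYates(v.data() + offset, curLen);
        }

        // 2) log2(numChunks) rounds of pairwise merges, doubling the span each round
        for (int64_t j = 1; j < numChunks; j += j) {
            for (int64_t i = 0; i < numChunks; i += 2 * j) {
                const int64_t offset = (len * i) >> power;
                const int64_t len1   = ((len * (i + j)) >> power) - offset;
                const int64_t totLen = ((len * (i + 2 * j)) >> power) - offset;
                mergeShuffle(v.data() + offset, len1, totLen);
            }
        }

        // the result must still be a permutation of 0..len-1
        std::sort(v.begin(), v.end());
        bool ok = true;
        for (int64_t i = 0; i < len; ++i) ok = ok && v[i] == (int)i;
        std::cout << "still a permutation: " << std::boolalpha << ok << "\n";
    }

Because every chunk, and later every merged pair, occupies a disjoint slice of the buffer, both loop nests can be parallelized directly, which is exactly what the CPU (samediff::Threads) and CUDA versions in the diff below do.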
@@ -16,7 +16,8 @@
 //
 // @author Yurii Shyrma (iuriish@yahoo.com), created on 20.04.2018
 //
+// implementation is based on following article:
+// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
 //

@@ -31,96 +32,167 @@ namespace ops {
 namespace helpers {

 //////////////////////////////////////////////////////////////////////////
+// Fisher-Yates shuffle
 template <typename T>
-void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
+static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {
+
+    for(Nd4jLong i = len-1; i > 0; --i) {
+        const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
+        if(i != j)
+            math::nd4j_swap<T>(buff[i*ews], buff[j*ews]);
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+// mutual shuffle of two adjacent already shuffled ranges with length len1 and (totLen - len1) correspondingly
+template <typename T>
+static void mergeShuffle(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len1, const Nd4jLong& totLen, const Nd4jLong& ews, Nd4jLong ind) {
+
+    Nd4jLong beg = 0;    // beginning
+    Nd4jLong mid = len1; // middle
+
+    while (true) {
+        if(rng.relativeLong(ind++) % 2) {
+            if(mid == totLen)
+                break;
+            math::nd4j_swap<T>(buff[ews * beg], buff[ews * mid++]);
+        } else {
+            if(beg == mid)
+                break;
+        }
+        ++beg;
+    }
+
+    // fisherYates
+    while (beg < totLen) {
+        const Nd4jLong j = rng.relativeLong(ind++) % (beg + 1);
+        if(beg != j)
+            math::nd4j_swap<T>(buff[ews * beg], buff[ews * j]);
+        ++beg;
+    }
+}
+
+//////////////////////////////////////////////////////////////////////////
+template <typename T>
+static void randomShuffle_(NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {

-    // check edge cases first
-    int temp;
     const int firstDim = input.sizeAt(0);
+    int temp;

     if(input.lengthOf() == 1 || firstDim == 1) {

         if(!isInplace)
             output.assign(input);
     }
-    else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {
-
-        // apply Fisher-Yates shuffle
-        if(isInplace) {
-            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
-            for(int i = firstDim-1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                if(i == r)
-                    continue;
-                T t0 = input.t<T>(i);
-                T t1 = input.t<T>(r);
-                //math::nd4j_swap<T>(input(i), input(r));
-                input.r<T>(i) = t1;
-                input.r<T>(r) = t0;
-            }
-        }
-        else {
-            std::vector<int> indices(firstDim);
-            std::iota(indices.begin(), indices.end(), 0);
-            output.p<T>(Nd4jLong(0), input.e<T>(0));
-
-            // FIXME: parallelism!!
-            for(int i = firstDim-1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                output.r<T>(i) = input.t<T>(indices[r]);
-                if(i == r)
-                    continue;
-                output.r<T>(r) = input.t<T>(indices[i]);
-                math::nd4j_swap<int>(indices[i], indices[r]);
-            }
-            rng.rewindH(firstDim-1);
-        }
+    else if (shape::isCommonVector(input.shapeInfo(), temp)) {
+
+        NDArray* arr = &input;
+
+        if (!isInplace) {
+            output.assign(input);
+            arr = &output;
+        }
+
+        const Nd4jLong ews = arr->ews();
+        const Nd4jLong len = arr->lengthOf();
+        const Nd4jLong threshold = 1<<22; // this number was deduced from diagram in article
+
+        int power = 0;
+        while ((len >> power) > threshold)
+            ++power;
+
+        const Nd4jLong numChunks = 1 << power;
+
+        auto funcFisherYates = PRAGMA_THREADS_FOR {
+
+            for (auto i = start; i < stop; ++i) {
+
+                Nd4jLong offset = (len * i) >> power;
+                Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
+                fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
+            }
+        };
+
+        auto funcMerge = PRAGMA_THREADS_FOR {
+
+            for (int64_t i = start, k = 1; i < stop; i += increment, ++k) {
+                Nd4jLong offset = len * i >> power;
+                Nd4jLong len1 = (len * (i + increment/2) >> power) - offset;
+                Nd4jLong totLen = (len * (i + increment) >> power) - offset;
+                mergeShuffle<T>(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * k + offset);
+            }
+        };
+
+        samediff::Threads::parallel_for(funcFisherYates, 0, numChunks);
+
+        for (int j = 1; j < numChunks; j += j)
+            samediff::Threads::parallel_for(funcMerge, 0, numChunks, 2*j);
+
+        // #pragma omp parallel for
+        // for (uint i = 0; i < numChunks; ++i) {
+
+        //     Nd4jLong offset = (len * i) >> power;
+        //     Nd4jLong currLen = ((len * (i + 1)) >> power) - offset;
+        //     fisherYates<T>(rng, arr->bufferAsT<T>() + offset*ews, currLen, ews, offset);
+        // }
+
+        // for (uint j = 1; j < numChunks; j += j) {
+        //     #pragma omp parallel for
+        //     for (auto i = 0; i < numChunks; i += 2*j) {
+        //         Nd4jLong offset = len * i >> power;
+        //         Nd4jLong len1 = (len * (i + j) >> power) - offset;
+        //         Nd4jLong totLen = (len * (i + 2*j) >> power) - offset;
+        //         mergeShuffle(rng, arr->bufferAsT<T>() + offset*ews, len1, totLen, ews, len * j + offset);
+        //     }
+        // }
+
+        rng.rewindH((len + 1) * power);
     }
     else {

-        // evaluate sub-arrays list of input array through all dimensions excluding first one
-        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
-        auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
+        auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});

-        // apply Fisher-Yates shuffle
         if(isInplace) {
-            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().elementwiseThreshold())
-            for(int i = firstDim - 1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-
-                if(i == r)
-                    continue;
-                subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
+
+            auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);
+
+            // Fisher-Yates shuffle
+            for(int i = firstDim - 1; i > 0; --i) {
+                const int j = rng.relativeInt(i) % (i + 1);
+                if(i != j)
+                    subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
             }
         }
         else {
-            // evaluate sub-arrays list of output array through all dimensions excluding first one
-            auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
+
+            auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude);
+            auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);
+
             std::vector<int> indices(firstDim);
-            std::iota(indices.begin(), indices.end(), 0);
-            bool isZeroShuffled = false;
-            //PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
-            for(int i = firstDim - 1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
-                if(r == 0)
-                    isZeroShuffled = true;
-                if(i == r)
-                    continue;
-                subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
-                math::nd4j_swap<int>(indices[i], indices[r]);
-            }
-            if(!isZeroShuffled)
-                subArrsListOut.at(0)->assign(subArrsListIn.at(0));
+            std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1
+
+            // shuffle indices
+            fisherYates<int>(rng, indices.data(), firstDim, 1, 0);
+
+            auto func = PRAGMA_THREADS_FOR {
+
+                for (auto i = start; i < stop; ++i)
+                    subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
+            };
+
+            samediff::Threads::parallel_for(func, 0, firstDim);
         }
+
         rng.rewindH(firstDim-1);
     }
-
 }

 void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
     BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (input, output, rng, isInplace), LIBND4J_TYPES);
 }

 }
 }
 }

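A note on the chunk arithmetic used by funcFisherYates and funcMerge above: offset_i = (len * i) >> power yields 2^power contiguous, non-overlapping chunks that cover [0, len) exactly, even when len is not divisible by the chunk count, because chunk i ends precisely where chunk i+1 begins and (len << power) >> power == len. A tiny standalone C++ check (the values are illustrative only):

    #include <cstdint>
    #include <cstdio>

    int main() {
        const int64_t len = 16010;  // deliberately not a multiple of the chunk count
        const int power = 3;        // 2^3 = 8 chunks
        int64_t covered = 0;
        for (int64_t i = 0; i < (int64_t(1) << power); ++i) {
            const int64_t offset = (len * i) >> power;             // start of chunk i
            const int64_t curLen = ((len * (i + 1)) >> power) - offset;
            std::printf("chunk %2lld: offset=%6lld len=%lld\n",
                        (long long)i, (long long)offset, (long long)curLen);
            covered += curLen;
        }
        std::printf("covered=%lld, len=%lld\n", (long long)covered, (long long)len);
        return 0;
    }

Chunk lengths differ by at most one element, so the per-chunk Fisher-Yates work stays balanced across threads.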
@@ -53,7 +53,7 @@ __global__ static void concatCuda(void* pVx, void* pxShapeInfo, void* vz, const

     int coords[MAX_RANK];

-    for (uint64_t i = tid; i < zLen; i += totalThreads) {
+    for (Nd4jLong i = tid; i < zLen; i += totalThreads) {

         shape::index2coords(i, zShapeInfo, coords);

         const auto zOffset = shape::getOffset(zShapeInfo, coords);

@@ -162,9 +162,9 @@ void concat(sd::LaunchContext * context, const std::vector<const NDArray*>& inAr
     // }
     // else { // general (slower) case

-    const int threadsPerBlock = 256;
-    const int blocksPerGrid = 512;
-    const int sharedMem = 512;
+    const int threadsPerBlock = MAX_NUM_THREADS / 2;
+    const int blocksPerGrid = (output.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
+    const int sharedMem = 256;

     // prepare arrays of pointers on buffers and shapes
     std::vector<const void*> hInBuffers(numOfInArrs);

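The launch-configuration change above replaces a fixed 512-block grid with one sized from output.lengthOf(); the kernel stays correct for any grid size because it iterates with a grid-stride loop. A minimal standalone CUDA sketch of that pattern (names and sizes are illustrative, not the library kernel):

    #include <cstdio>
    #include <cuda_runtime.h>

    __global__ void iotaKernel(long long* z, const long long zLen) {
        const long long tid = blockIdx.x * (long long)blockDim.x + threadIdx.x;
        const long long totalThreads = (long long)gridDim.x * blockDim.x;
        for (long long i = tid; i < zLen; i += totalThreads)   // grid-stride loop
            z[i] = i;
    }

    int main() {
        const long long zLen = 1 << 20;
        long long* z = nullptr;
        cudaMalloc(&z, zLen * sizeof(long long));

        const int threadsPerBlock = 512;   // stands in for MAX_NUM_THREADS / 2
        const int blocksPerGrid = (int)((zLen + threadsPerBlock - 1) / threadsPerBlock);
        iotaKernel<<<blocksPerGrid, threadsPerBlock>>>(z, zLen);
        cudaDeviceSynchronize();

        long long last = 0;
        cudaMemcpy(&last, z + zLen - 1, sizeof(long long), cudaMemcpyDeviceToHost);
        std::printf("last element: %lld (expect %lld)\n", last, zLen - 1);
        cudaFree(z);
        return 0;
    }

Sizing the grid from the output length gives roughly one thread per element on typical inputs, while the stride loop still handles outputs larger than the grid.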
@@ -0,0 +1,228 @@
/*******************************************************************************
 * Copyright (c) 2020 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author Yurii Shyrma (iuriish@yahoo.com)
// implemented algorithm is GPU adaptation of algorithm described in following article:
// "MergeShuffle: A Very Fast, Parallel Random Permutation Algorithm", https://arxiv.org/abs/1508.03167
//

#include<ops/declarable/helpers/transforms.h>
#include <array/ResultSet.h>
#include <numeric>
#include <execution/Threads.h>
#include <helpers/ShapeUtils.h>
#include <helpers/PointersManager.h>

namespace sd {
namespace ops {
namespace helpers {

//////////////////////////////////////////////////////////////////////////
template <typename T>
static __global__ void fisherYatesCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power) {

    T* x = reinterpret_cast<T*>(vx);

    __shared__ T* shmem, temp;
    __shared__ Nd4jLong ind, blockOffset, lenPerBlock;

    if (threadIdx.x == 0) {
        extern __shared__ unsigned char sharedMemory[];
        shmem = reinterpret_cast<T*>(sharedMemory);

        blockOffset = (len * blockIdx.x) >> power;
        lenPerBlock = ((len * (blockIdx.x + 1)) >> power) - blockOffset;
        ind = blockOffset;
    }
    __syncthreads();

    // copy from global memory to shared memory
    if(threadIdx.x < lenPerBlock)
        shmem[threadIdx.x] = x[(blockOffset + threadIdx.x) * ews];
    __syncthreads();

    // *** apply Fisher-Yates shuffle to lenPerBlock number of elements
    if (threadIdx.x == 0) {
        for(Nd4jLong i = lenPerBlock - 1; i > 0; --i) {
            const Nd4jLong j = rng->relativeLong(ind++) % (i + 1);
            if(i != j) {
                temp = shmem[i];
                shmem[i] = shmem[j];
                shmem[j] = temp;
            }
        }
    }
    __syncthreads();

    // copy from shared memory to global memory
    if(threadIdx.x < lenPerBlock)
        x[(blockOffset + threadIdx.x) * ews] = shmem[threadIdx.x];
}

template <typename T>
static __global__ void mergeShuffleCuda(sd::graph::RandomGenerator* rng, void* vx, const Nd4jLong ews, const Nd4jLong len, const int power, const Nd4jLong iterNum) {

    T* x = reinterpret_cast<T*>(vx);

    __shared__ Nd4jLong ind, blockOffset, factor, beg, mid, totLen, iterExp;

    // *** apply mergeShuffle algorithm
    if(threadIdx.x == 0) {

        factor = blockIdx.x << iterNum;
        iterExp = 1 << (iterNum - 1);
        blockOffset = (len * factor) >> power;
        mid = ((len * (factor + iterExp)) >> power) - blockOffset; // middle
        totLen = ((len * (factor + 2*iterExp)) >> power) - blockOffset;
        ind = iterNum * len + blockOffset;
        beg = 0; // beginning

        // printf("m %lld, blockIdx.x %lld, factor %lld, blockOffset %lld, mid %lld, totLen %lld \n", m,k,factor,blockOffset,mid,totLen);

        while (true) {
            if(rng->relativeLong(ind++) % 2) {
                if(mid == totLen)
                    break;
                math::nd4j_swap<T>(x[(blockOffset + beg) * ews], x[(blockOffset + mid++) * ews]);
            } else {
                if(beg == mid)
                    break;
            }
            ++beg;
        }

        // Fisher-Yates
        while (beg < totLen) {
            const Nd4jLong e = rng->relativeLong(ind++) % (beg + 1);
            if(beg != e)
                math::nd4j_swap<T>(x[(blockOffset + beg) * ews], x[(blockOffset + e) * ews]);
            ++beg;
        }
    }
}

//////////////////////////////////////////////////////////////////////////
// Fisher-Yates shuffle
template <typename T>
static void fisherYates(sd::graph::RandomGenerator& rng, T* buff, const Nd4jLong& len, const Nd4jLong& ews, Nd4jLong ind) {

    for(Nd4jLong i = len-1; i > 0; --i) {
        const Nd4jLong j = rng.relativeLong(ind++) % (i + 1);
        if(i != j)
            math::nd4j_swap<T>(buff[i*ews], buff[j*ews]);
    }
}

//////////////////////////////////////////////////////////////////////////
template <typename T>
static void randomShuffle_(sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {

    const int firstDim = input.sizeAt(0);
    int temp;

    if(input.lengthOf() == 1 || firstDim == 1) {

        if(!isInplace)
            output.assign(input);
    }
    else if (shape::isCommonVector(input.shapeInfo(), temp)) {

        NDArray* arr = &input;

        if (!isInplace) {
            output.assign(input);
            arr = &output;
        }

        const Nd4jLong len = arr->lengthOf();

        const int threadsPerBlock = MAX_NUM_THREADS;

        int power = 0;
        while ((len >> power) > threadsPerBlock)
            ++power;

        const int blocksPerGrid = 1 << power;
        const int sharedMem = threadsPerBlock * input.sizeOfT() + 256;

        PointersManager manager(context, "NDArray::randomShuffle cuda");

        sd::graph::RandomGenerator* pRng = reinterpret_cast<sd::graph::RandomGenerator*>(manager.replicatePointer(&rng, sizeof(sd::graph::RandomGenerator)));

        NDArray::prepareSpecialUse({arr}, {arr});
        fisherYatesCuda<T><<<blocksPerGrid, threadsPerBlock, sharedMem, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power);
        for (Nd4jLong j = 1, i = 1; j < blocksPerGrid; j += j, ++i)
            mergeShuffleCuda<T><<<blocksPerGrid/(2*j), threadsPerBlock, 256, *context->getCudaStream()>>>(pRng, arr->specialBuffer(), arr->ews(), len, power, i);
        NDArray::registerSpecialUse({arr}, {arr});

        manager.synchronize();

        rng.rewindH((len + 1) * power);
    }
    else {

        auto dimsToExclude = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});

        if(isInplace) {

            auto subArrsList = input.allTensorsAlongDimension(dimsToExclude);

            // Fisher-Yates shuffle
            for(int i = firstDim - 1; i > 0; --i) {
                const int j = rng.relativeInt(i) % (i + 1);
                if(i != j)
                    subArrsList.at(i)->swapUnsafe(*subArrsList.at(j));
            }
        }
        else {

            auto subArrsListIn = input.allTensorsAlongDimension(dimsToExclude);
            auto subArrsListOut = output.allTensorsAlongDimension(dimsToExclude);

            std::vector<int> indices(firstDim);
            std::iota(indices.begin(), indices.end(), 0); // 0,1,2,3, ... firstDim-1

            // shuffle indices
            fisherYates<int>(rng, indices.data(), firstDim, 1, 0);

            auto func = PRAGMA_THREADS_FOR {

                for (auto i = start; i < stop; ++i)
                    subArrsListOut.at(i)->assign(subArrsListIn.at(indices[i]));
            };

            samediff::Threads::parallel_for(func, 0, firstDim);
        }

        rng.rewindH(firstDim-1);
    }
}

/////////////////////////////////////////////////////////////////////////
void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
    BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
}

// BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext* context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);

}
}
}

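To summarize the GPU schedule in the new file above: one fisherYatesCuda launch shuffles each of the 2^power chunks inside its own block's shared memory, then mergeShuffleCuda is launched log2(blocksPerGrid) times, where round i merges adjacent runs of 2^(i-1) chunks using half as many blocks as the round before. A plain C++ sketch that only prints this schedule (the block count is an assumed example):

    #include <cstdio>

    int main() {
        const int power = 4;                  // assume blocksPerGrid = 2^4 = 16
        const int blocksPerGrid = 1 << power;

        std::printf("fisherYatesCuda  <<<%d blocks>>> : one chunk per block\n", blocksPerGrid);
        // mirrors: for (Nd4jLong j = 1, i = 1; j < blocksPerGrid; j += j, ++i)
        //              mergeShuffleCuda<<<blocksPerGrid/(2*j), ...>>>(..., i);
        for (int j = 1, i = 1; j < blocksPerGrid; j += j, ++i)
            std::printf("mergeShuffleCuda <<<%2d blocks>>> : round %d, each block merges two runs of %d chunk(s)\n",
                        blocksPerGrid / (2 * j), i, j);
        return 0;
    }

Splitting the rounds into separate kernel launches (per the commit message) gives the grid-wide synchronization the pairwise merges need: every round must see the fully merged runs of the previous one before it starts.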
@@ -300,129 +300,6 @@ void tileBP(sd::LaunchContext * context, const NDArray& gradO /*input*/, NDArray
     manager.synchronize();
 }

-template <typename T>
-static __global__ void swapShuffleKernel(T* input, Nd4jLong const* shape, Nd4jLong firstDim, sd::graph::RandomGenerator* rng) {
-    auto tid = blockIdx.x * blockDim.x;
-    auto step = blockDim.x * gridDim.x;
-
-    for (int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
-        int r = rng->relativeInt(i) % i;
-        if (i != r) {
-            const auto iOffset = shape::getIndexOffset(i, shape);
-            const auto rOffset = shape::getIndexOffset(r, shape);
-            T e0 = input[iOffset];
-            T e1 = input[rOffset];
-            //math::nd4j_swap<T>(input(i), input(r));
-            input[iOffset] = e1;
-            input[rOffset] = e0;
-        }
-    }
-}
-
-template <typename T>
-static __global__ void fillShuffleKernel(T* input, Nd4jLong const* inputShape, T* output, Nd4jLong const* outputShape, Nd4jLong firstDim, int* indices, sd::graph::RandomGenerator* rng) {
-
-    // PRAGMA_OMP_PARALLEL_FOR_IF((firstDim-1) > Environment::getInstance().tadThreshold())
-    auto tid = blockIdx.x * blockDim.x;
-    auto step = blockDim.x * gridDim.x;
-
-    for(int i = firstDim - 1 - tid - threadIdx.x; i > 0; i -= step) {
-        int r = rng->relativeInt(i) % i;
-        output[shape::getIndexOffset(i, outputShape)] = input[shape::getIndexOffset(indices[r], inputShape)];
-        if(i != r) {
-            output[shape::getIndexOffset(r, outputShape)] = input[shape::getIndexOffset(indices[i], inputShape)];
-            // output.p(r, input.e<T>(indices[i]));
-            // math::nd4j_swap<int>(indices[i], indices[r]);
-            atomicExch(&indices[i], indices[r]);
-        }
-    }
-}
-
-//////////////////////////////////////////////////////////////////////////
-template <typename T>
-void randomShuffle_(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
-
-    // check edge cases first
-    int temp;
-    const int firstDim = input.sizeAt(0);
-    auto stream = context->getCudaStream();
-    NDArray::prepareSpecialUse({&output}, {&input});
-    if(input.lengthOf() == 1 || firstDim == 1) {
-        if(!isInplace)
-            output.assign(input);
-    }
-    else if (input.isVector() || shape::isLikeVector(input.shapeInfo(), temp)) {
-
-        // apply Fisher-Yates shuffle
-        sd::graph::RandomGenerator* dRandom = nullptr;
-        cudaMalloc(&dRandom, sizeof(sd::graph::RandomGenerator));
-        cudaMemcpy(dRandom, &rng, sizeof(sd::graph::RandomGenerator), cudaMemcpyHostToDevice);
-        T* inputBuf = reinterpret_cast<T*>(input.specialBuffer());
-        if(isInplace) {
-            swapShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), firstDim, dRandom);
-        }
-        else {
-            std::vector<int> indices(firstDim);
-            std::iota(indices.begin(), indices.end(), 0);
-            cudaMemcpy(output.specialBuffer(), input.specialBuffer(), sizeof(T), cudaMemcpyDeviceToDevice);
-            //output.p<T>(Nd4jLong(0), input.e<T>(0));
-            PointersManager pointersManager(context, "helper::randomShuffle_");
-            int* indicesDev = reinterpret_cast<int*>(pointersManager.replicatePointer(indices.data(), indices.size() * sizeof(int)));
-            T* outputBuf = reinterpret_cast<T*>(output.specialBuffer());
-            fillShuffleKernel<T><<<128, 256, 1024, *stream>>>(inputBuf, input.specialShapeInfo(), outputBuf, output.specialShapeInfo(), firstDim, indicesDev, dRandom);
-            pointersManager.synchronize();
-        }
-        // rng.rewindH(firstDim - 1);
-        cudaFree(dRandom);
-    }
-    else {
-
-        // evaluate sub-arrays list of input array through all dimensions excluding first one
-        std::vector<int> dimensions = ShapeUtils::evalDimsToExclude(input.rankOf(), {0});
-        auto subArrsListIn = input.allTensorsAlongDimension(dimensions);
-
-        // apply Fisher-Yates shuffle
-        if(isInplace) {
-            for(int i = firstDim - 1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-
-                if(i != r)
-                    subArrsListIn.at(i)->swapUnsafe(*subArrsListIn.at(r));
-            }
-        }
-        else {
-            // evaluate sub-arrays list of output array through all dimensions excluding first one
-            auto subArrsListOut = output.allTensorsAlongDimension(dimensions);
-            std::vector<int> indices(firstDim);
-            std::iota(indices.begin(), indices.end(), 0);
-            bool isZeroShuffled = false;
-
-            for(int i = firstDim - 1; i > 0; --i) {
-                int r = rng.relativeInt(i) % i;
-                subArrsListOut.at(i)->assign(subArrsListIn.at(indices[r]));
-                if(r == 0)
-                    isZeroShuffled = true;
-
-                if(i != r) {
-                    subArrsListOut.at(r)->assign(subArrsListIn.at(indices[i]));
-                    math::nd4j_swap<int>(indices[i], indices[r]);
-                }
-            }
-            if(!isZeroShuffled)
-                subArrsListOut.at(0)->assign(subArrsListIn.at(0));
-        }
-        rng.rewindH(firstDim-1);
-    }
-    NDArray::registerSpecialUse({&output}, {&input});
-
-}
-
-void randomShuffle(sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace) {
-    BUILD_SINGLE_SELECTOR(input.dataType(), randomShuffle_, (context, input, output, rng, isInplace), LIBND4J_TYPES);
-}
-
-BUILD_SINGLE_TEMPLATE(template void randomShuffle_, (sd::LaunchContext * context, NDArray& input, NDArray& output, sd::graph::RandomGenerator& rng, const bool isInplace), LIBND4J_TYPES);
-
-
 //////////////////////////////////////////////////////////////////////////
 void eye(sd::LaunchContext * context, NDArray& output) {

@@ -419,3 +419,4 @@ TEST_F(DeclarableOpsTests19, test_squeeze_1) {
     auto status = op.execute({&x}, {&e}, {axis});
     ASSERT_EQ(Status::OK(), status);
 }
+

@@ -1557,8 +1557,6 @@ TEST_F(DeclarableOpsTests5, trace_test1) {
     // exp.printIndexedBuffer("EXP TRACE");
     // output->printIndexedBuffer("OUT TRACE");
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 //////////////////////////////////////////////////////////////////////

@@ -1575,8 +1573,6 @@ TEST_F(DeclarableOpsTests5, trace_test2) {
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 //////////////////////////////////////////////////////////////////////

@@ -1593,8 +1589,6 @@ TEST_F(DeclarableOpsTests5, trace_test3) {
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 //////////////////////////////////////////////////////////////////////

@@ -1611,8 +1605,6 @@ TEST_F(DeclarableOpsTests5, trace_test4) {
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 //////////////////////////////////////////////////////////////////////

@@ -1629,8 +1621,6 @@ TEST_F(DeclarableOpsTests5, trace_test5) {
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 //////////////////////////////////////////////////////////////////////

@@ -1638,22 +1628,15 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test1) {

     auto input = NDArrayFactory::create<double>('c', {2, 2, 2});
     input.linspace(1);
+    NDArray exp1 = input.dup();
+    NDArray exp2('c',{2,2,2}, {5,6,7,8, 1,2,3,4}, sd::DataType::DOUBLE);

     sd::ops::random_shuffle op;
     auto results = op.evaluate({&input});
     auto output = results.at(0);

-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
-
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);
-
+    ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2));
 }

 //////////////////////////////////////////////////////////////////////

@@ -1661,16 +1644,14 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test2) {

     auto input = NDArrayFactory::create<double>('c', {1, 3, 2});
     input.linspace(1);
+    NDArray exp1 = input.dup();

     sd::ops::random_shuffle op;
     auto results = op.evaluate({&input});
     auto output = results.at(0);

     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(input.equalsTo(output));
-
-
+    ASSERT_TRUE(output->equalsTo(exp1));
 }

 //////////////////////////////////////////////////////////////////////

@@ -1678,129 +1659,132 @@ TEST_F(DeclarableOpsTests5, random_shuffle_test3) {

     auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
     input.linspace(1);
+    NDArray exp1 = input.dup();
+    NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE);
+    NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE);
+    NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE);
+    NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE);
+    NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE);

     sd::ops::random_shuffle op;
     auto results = op.evaluate({&input});
     auto output = results.at(0);

-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
-
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);
+    ASSERT_TRUE(output->equalsTo(exp1) || output->equalsTo(exp2) || output->equalsTo(exp3)
+             || output->equalsTo(exp4) || output->equalsTo(exp5) || output->equalsTo(exp6));
 }

 //////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests5, random_shuffle_test04) {
+
+    auto input = NDArrayFactory::create<double>('c', {3, 2, 1});
+    input.linspace(1);
+    NDArray exp1 = input.dup();
+    NDArray exp2('c',{3,2,1}, {1,2, 5,6, 3,4}, sd::DataType::DOUBLE);
+    NDArray exp3('c',{3,2,1}, {3,4, 1,2, 5,6}, sd::DataType::DOUBLE);
+    NDArray exp4('c',{3,2,1}, {3,4, 5,6, 1,2}, sd::DataType::DOUBLE);
+    NDArray exp5('c',{3,2,1}, {5,6, 1,2, 3,4}, sd::DataType::DOUBLE);
+    NDArray exp6('c',{3,2,1}, {5,6, 3,4, 1,2}, sd::DataType::DOUBLE);
+
+    sd::ops::random_shuffle op;
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, true); // in-place
+
+    ASSERT_EQ(Status::OK(), results.status());
+    ASSERT_TRUE(input.equalsTo(exp1) || input.equalsTo(exp2) || input.equalsTo(exp3)
+             || input.equalsTo(exp4) || input.equalsTo(exp5) || input.equalsTo(exp6));
+}
+
+//////////////////////////////////////////////////////////////////////
-TEST_F(DeclarableOpsTests5, random_shuffle_test4) {
-    auto input = NDArrayFactory::create<double>('c', {4});
-    input.linspace(1);
-
-    sd::ops::random_shuffle op;
-    //NDArray* output;
-    auto results = op.evaluate({&input}, {}, {}, {}, {}, true);
-
-    ASSERT_EQ(Status::OK(), results.status());
-    auto output = &input; //results.at(0);
-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
-
-    ASSERT_TRUE(input.isSameShape(output));
-    //ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);
-}

 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, random_shuffle_test5) {

-    auto input = NDArrayFactory::create<double>('c', {4,1});
+    auto input = NDArrayFactory::create<int>('c', {4});
     input.linspace(1);

     sd::ops::random_shuffle op;
-    auto results = op.evaluate({&input});
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
     auto output = results.at(0);
+    // output->printBuffer();

-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
-
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);

+    // ASSERT_TRUE(!output->equalsTo(input));
+
+    bool hasDublicates = false;
+    for(int i = 0; i < output->lengthOf() - 1; ++i)
+        for(int j = i+1; j < output->lengthOf(); ++j)
+            if(output->t<int>(i) == output->t<int>(j)) {
+                hasDublicates = true;
+                i = output->lengthOf();
+                break;
+            }
+    ASSERT_TRUE(!hasDublicates);
 }

 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, random_shuffle_test6) {

-    auto input = NDArrayFactory::create<double>('c', {4,1,1});
+    auto input = NDArrayFactory::create<int>('c', {4,1,1});
     input.linspace(1);

     sd::ops::random_shuffle op;
-    auto results = op.evaluate({&input});
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
     auto output = results.at(0);

-    bool haveZeros = false;
-    for(int i = 0; i < output->lengthOf(); ++i)
-        if(output->e<float>(i) == (float)0.)
-            haveZeros = true;
-
     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(!input.equalsTo(output));
-    ASSERT_TRUE(!haveZeros);

+    // ASSERT_TRUE(!output->equalsTo(input));
+
+    bool hasDublicates = false;
+    for(int i = 0; i < output->lengthOf() - 1; ++i)
+        for(int j = i+1; j < output->lengthOf(); ++j)
+            if(output->t<int>(i) == output->t<int>(j)) {
+                hasDublicates = true;
+                i = output->lengthOf();
+                break;
+            }
+    ASSERT_TRUE(!hasDublicates);
 }

 //////////////////////////////////////////////////////////////////////
 TEST_F(DeclarableOpsTests5, random_shuffle_test7) {

-    auto input = NDArrayFactory::create<double>('c', {1,4});
+    auto input = NDArrayFactory::create<int>('c', {16010});
     input.linspace(1);
-    auto exp = NDArrayFactory::create<double>('c', {1,4}, {1, 2, 3, 4});

     sd::ops::random_shuffle op;
-    auto results = op.evaluate({&input});
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
     auto output = results.at(0);
+    // output->printBuffer();

     ASSERT_EQ(Status::OK(), results.status());
     ASSERT_TRUE(input.isSameShape(output));
-    ASSERT_TRUE(input.equalsTo(output));
+    ASSERT_TRUE(!output->equalsTo(input));

+    auto vec1 = input.getBufferAsVector<int>();
+    auto vec2 = output->getBufferAsVector<int>();
+    std::sort(vec2.begin(), vec2.end());
+    ASSERT_TRUE(std::equal(vec1.begin(), vec1.end(), vec2.begin()));
 }

 //////////////////////////////////////////////////////////////////////
+TEST_F(DeclarableOpsTests5, random_shuffle_test8) {
+    auto input = NDArrayFactory::create<int>('c', {1,4,1});
+    input.linspace(1);
+    NDArray inCopy = input.dup();
+
+    sd::ops::random_shuffle op;
+    auto results = op.evaluate({&input}, {}, {}, {}, {}, false);
+    ASSERT_EQ(Status::OK(), results.status());
+    ASSERT_TRUE(input.equalsTo(inCopy));
+}
+
+TEST_F(DeclarableOpsTests5, random_shuffle_test9) {
+
+    auto x = NDArrayFactory::create<int>('c', {4}, {1, 2, 3, 4});
+    auto z = x.ulike();
+
+    sd::ops::random_shuffle op;
+    auto status = op.execute({&x}, {&z});
+    ASSERT_EQ(Status::OK(), status);
+
+    auto vec = z.getBufferAsVector<int>();
+    std::sort(vec.begin(), vec.end());
+    ASSERT_EQ(std::vector<int>({1, 2, 3, 4}), vec);
+}

 ////////////////////////////////////////////////////////////////////////////////////////

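The sort-then-compare idiom in random_shuffle_test7 and random_shuffle_test9 verifies exactly the permutation property: the shuffled buffer holds the same multiset of values as the input. With the standard library the same check can be written directly via std::is_permutation, as in this standalone sketch (sizes and seed are illustrative, not test code):

    #include <algorithm>
    #include <cassert>
    #include <numeric>
    #include <random>
    #include <vector>

    int main() {
        std::vector<int> before(16010);
        std::iota(before.begin(), before.end(), 1);  // 1..16010, as in random_shuffle_test7

        std::vector<int> after(before);
        std::shuffle(after.begin(), after.end(), std::mt19937(42));

        // same multiset of values, order aside -- exactly what the tests assert
        assert(std::is_permutation(before.begin(), before.end(), after.begin()));
        return 0;
    }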
@@ -251,11 +251,10 @@ TEST_F(DeclarableOpsTests9, concat_test1) {

     auto result = op.evaluate({&x0, &x1, &x2}, {}, {1});
     ASSERT_EQ(ND4J_STATUS_OK, result.status());
     auto output = result.at(0);
+    // output->printCurrentBuffer<float>(false);

     ASSERT_TRUE(exp.isSameShape(output));
     ASSERT_TRUE(exp.equalsTo(output));
-
-
 }

 ////////////////////////////////////////////////////////////////////////////////

@@ -317,7 +317,7 @@ void fill_random(sd::NDArray& arr) {
     }
 }

 void testLegacy(bool random) {
 #if 0
     int bases[] = { 3, 2, 4, 5, 7 };

@@ -364,7 +364,7 @@ int k = 4;
 #endif
     auto dim = NDArrayFactory::create<int>(dimension);

 #if 1
 #if 1
     nd4j_printf("C(N:%d K:%d) \n", N, k);
     dim.printIndexedBuffer("Dimension");
     for (int xind : dimension) {

@@ -385,7 +385,7 @@ for (int e = 0; e < Loop; e++) {
     auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
     values.emplace_back(outerTime);
 }

 std::sort(values.begin(), values.end());

 nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);

@@ -411,7 +411,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
 constexpr int N = 5;

 #endif

 for (int i = 0; i < N; i++) {
     arr_dimensions.push_back(bases[i]);
 }

@@ -451,7 +451,7 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
 #endif
     auto dim = NDArrayFactory::create<int>(dimension);

 #if 1
 #if 1
     nd4j_printf("C(N:%d K:%d) \n", N, k);
     dim.printIndexedBuffer("Dimension");
     for (int xind : dimension) {

@@ -477,14 +477,14 @@ void testNewReduction(bool random, bool checkCorrectness = false , char order ='
     //check for the correctness
     NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create<Nd4jLong>('c', output_bases) : NDArrayFactory::create<Nd4jLong>(0);
     original_argmax(x, dimension, exp);

 #if 0// defined(DEBUG)
     x.printIndexedBuffer("X");
     exp.printIndexedBuffer("Expected");
     z->printIndexedBuffer("Z");
 #endif

     ASSERT_TRUE(exp.isSameShape(z));
     ASSERT_TRUE(exp.equalsTo(z));
 }

@@ -505,7 +505,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfLinspace) {
     testNewReduction(false, test_corr);
 }
 #endif

 TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
     testNewReduction(true, test_corr);
 }

@@ -513,7 +513,7 @@ TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
 TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) {
     testNewReduction(true, test_corr, 'f');
 }

 #if !defined(DEBUG)
 TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) {
     testLegacy(false);

@@ -1062,39 +1062,6 @@ TEST_F(PlaygroundTests, my) {
     delete variableSpace;
 }

-TEST_F(PlaygroundTests, my) {
-
-    int N = 100;
-    int bS=16, iH=128,iW=128, iC=32,oC=64, kH=4,kW=4, sH=1,sW=1, pH=0,pW=0, dH=1,dW=1;
-    int oH=128,oW=128;
-
-    int paddingMode = 1; // 1-SAME, 0-VALID;
-    int dataFormat = 1; // 1-NHWC, 0-NCHW
-
-    // NDArray input('c', {bS, iC, iH, iW}, sd::DataType::FLOAT32);
-    // NDArray output('c', {bS, oC, oH, oW}, sd::DataType::FLOAT32);
-    NDArray input('c', {bS, iH, iW, iC}, sd::DataType::FLOAT32);
-    NDArray output('c', {bS, oH, oW, oC}, sd::DataType::FLOAT32);
-    // NDArray weights('c', {kH, kW, iC, oC}, sd::DataType::FLOAT32); // permute [kH, kW, iC, oC] -> [oC, iC, kH, kW]
-    NDArray weights('c', {oC, iC, kH, kW}, sd::DataType::FLOAT32);
-    NDArray bias('c', {oC}, sd::DataType::FLOAT32);
-
-    input = 5.;
-    weights = 3.;
-    bias = 1.;
-
-    sd::ops::conv2d op;
-    auto err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
-
-    auto timeStart = std::chrono::system_clock::now();
-    for (int i = 0; i < N; ++i)
-        err = op.execute({&input, &weights, &bias}, {&output}, {kH,kW, sH,sW, pH,pW, dH,dW, paddingMode, dataFormat});
-    auto timeEnd = std::chrono::system_clock::now();
-    auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart) / N).count();
-
-    printf("time: %i \n", time);
-}
-
 ///////////////////////////////////////////////////////////////////
 TEST_F(PlaygroundTests, lstmLayerCellBp_1) {

@@ -1690,6 +1657,52 @@ TEST_F(DeclarableOpsTests15, gru_bp_1) {
     const bool isGradCorrect = GradCheck::checkGrad(opFF, opBP, argsHolderFF, argsHolderBP);
 }

+#include<ops/declarable/helpers/transforms.h>
+//////////////////////////////////////////////////////////////////////
+TEST_F(PlaygroundTests, my) {
+
+    const int N = 10;
+
+    NDArray input('c', {8000000}, sd::DataType::INT32);
+    input.linspace(1);
+    NDArray output = input.dup();
+
+    sd::graph::RandomGenerator rng;
+
+    sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true);
+
+    // auto timeStart = std::chrono::system_clock::now();
+    // for (int i = 0; i < N; ++i)
+    //     sd::ops::helpers::randomShuffle(input.getContext(), input, output, rng, true);
+    // auto timeEnd = std::chrono::system_clock::now();
+    // auto time = std::chrono::duration_cast<std::chrono::microseconds> ((timeEnd - timeStart) / N).count();
+    // printf("time: %i \n", time);
+
+    // bool hasDublicates = false;
+    // for(int i = 0; i < output.lengthOf() - 1; ++i)
+    //     for(int j = i+1; j < output.lengthOf(); ++j)
+    //         if(output.t<int>(i) == output.t<int>(j)) {
+    //             hasDublicates = true;
+    //             i = output.lengthOf();
+    //             break;
+    //         }
+
+    ASSERT_TRUE(!input.equalsTo(output));
+
+    bool hasDublicates = false;
+    for(int i = 0; i < input.lengthOf() - 1; ++i)
+        for(int j = i+1; j < input.lengthOf(); ++j)
+            if(input.t<int>(i) == input.t<int>(j)) {
+                hasDublicates = true;
+                i = input.lengthOf();
+                break;
+            }
+    ASSERT_TRUE(!hasDublicates);
+}

 }

 */
