[WIP] repeat op (#143)
* - write new repeat helper (cpu) Signed-off-by: Yurii <yurii@skymind.io> * - update NDArray::cpu Signed-off-by: Yurii <yurii@skymind.io> * - update NDArray::repeat cuda Signed-off-by: Yurii <yurii@skymind.io>master
parent
3cf72e5e30
commit
e604ffe0d2
|
@ -316,10 +316,10 @@ namespace nd4j {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* create a new array by replicating current array by repeats times along given dimension
|
* create a new array by replicating current array by repeats times along given dimension
|
||||||
* dimension - dimension along which to repeat elements
|
* axis - axis along which to repeat elements
|
||||||
* repeats - number of repetitions
|
* repeats - number of repetitions
|
||||||
*/
|
*/
|
||||||
NDArray* repeat(int dimension, const std::vector<Nd4jLong>& repeats) const;
|
NDArray* repeat(const int axis, const std::vector<int>& repeats) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This method fills this array with zeros
|
* This method fills this array with zeros
|
||||||
|
@ -344,9 +344,10 @@ namespace nd4j {
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* fill target array by repeating current array
|
* fill target array by repeating current array
|
||||||
* dimension - dimension along which to repeat elements
|
* axis - axis along which to repeat elements
|
||||||
|
* repeats - vector containing numbers of repetition for elements at given axis
|
||||||
*/
|
*/
|
||||||
void repeat(int dimension, NDArray& target) const;
|
void repeat(const int axis, const std::vector<int>& repeats, NDArray& target) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* creates array which points on certain sub-range of this array, sub-range is defined by given indices
|
* creates array which points on certain sub-range of this array, sub-range is defined by given indices
|
||||||
|
|
|
@ -363,79 +363,62 @@ void NDArray::tile(NDArray& target) const {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Z>
|
||||||
|
static void repeat_(const NDArray& input, NDArray& output, const std::vector<int>& repeats, const int axis) {
|
||||||
|
|
||||||
|
const X* x = input.bufferAsT<X>();
|
||||||
|
Z* z = output.bufferAsT<Z>();
|
||||||
|
|
||||||
|
const int rank = input.rankOf(); // xRank = zRank
|
||||||
|
const int zLen = output.lengthOf(); // xLen <= zLen
|
||||||
|
const int repSize = repeats.size();
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> coords(rank);
|
||||||
|
|
||||||
|
// loop through input array
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_ARGS(schedule(guided) firstprivate(coords))
|
||||||
|
for (Nd4jLong i = 0; i < zLen; ++i) {
|
||||||
|
|
||||||
|
shape::index2coords(rank, output.shapeOf(), i, zLen, coords.data());
|
||||||
|
|
||||||
|
const auto zOffset = shape::getOffset(0, output.shapeOf(), output.stridesOf(), coords.data(), rank);
|
||||||
|
|
||||||
|
if(repSize > 1) {
|
||||||
|
for (uint j = 0; j < repSize; ++j) {
|
||||||
|
coords[axis] -= repeats[j];
|
||||||
|
if (coords[axis] < 0) {
|
||||||
|
coords[axis] = j;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
coords[axis] /= repeats[0];
|
||||||
|
|
||||||
|
z[zOffset] = x[shape::getOffset(0, input.shapeOf(), input.stridesOf(), coords.data(), rank)];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// create new array by repeating it the number of times given by reps
|
// create new array by repeating it the number of times given by repeats
|
||||||
NDArray* NDArray::repeat(int dimension, const std::vector<Nd4jLong>& repeats) const {
|
NDArray* NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
|
||||||
auto outShape = ShapeUtils::evalRepeatShape(dimension, repeats, *this);
|
|
||||||
|
|
||||||
// the size of outShape == rank
|
auto output = new NDArray('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext());
|
||||||
int rank = rankOf(); // = outShape.size()
|
|
||||||
|
|
||||||
auto ret = new NDArray('c', outShape, dataType(), getContext());
|
BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeat_, (*this, *output, repeats, axis), LIBND4J_TYPES);
|
||||||
|
|
||||||
auto retArrs = ret->allTensorsAlongDimension({dimension});
|
return output;
|
||||||
auto thisArrs = this->allTensorsAlongDimension({dimension});
|
|
||||||
|
|
||||||
auto repeatDelta = shape::prodLong(outShape.data(), rank) / this->lengthOf();
|
|
||||||
auto numTads = retArrs->size();
|
|
||||||
|
|
||||||
for (int i = 0; i < numTads; i++) {
|
|
||||||
auto thisTensor = thisArrs->at(i);
|
|
||||||
auto retTensor = retArrs->at(i);
|
|
||||||
Nd4jLong retIdx = 0;
|
|
||||||
|
|
||||||
for (Nd4jLong k = 0; k < thisTensor->lengthOf(); k++) {
|
|
||||||
auto s = thisTensor->e(k);
|
|
||||||
for (Nd4jLong j = 0; j < repeatDelta; j++)
|
|
||||||
retTensor->p(retIdx++, s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
delete retArrs;
|
|
||||||
delete thisArrs;
|
|
||||||
return ret;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// fill array by repeating it the number of times given by reps
|
// fill array by repeating it the number of times given by reps
|
||||||
void NDArray::repeat(int dimension, NDArray& target) const {
|
void NDArray::repeat(const int axis, const std::vector<int>& repeats, NDArray& target) const {
|
||||||
|
|
||||||
if(dimension < 0)
|
if(!target.isSameShape(ShapeUtils::evalRepeatShape(axis, repeats, *this)))
|
||||||
dimension += rankOf();
|
throw std::invalid_argument("NDArray::repeat(const int axis, const std::vector<int>& repeats, NDArray& target) method: wrong shape of target array!");
|
||||||
|
|
||||||
if(rankOf() != target.rankOf())
|
BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeat_, (*this, target, repeats, axis), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
throw std::invalid_argument("NDArray::repeat(int dimension, NDArray& target) method: wrong rank of target array it must be equal to this array rank!");
|
|
||||||
|
|
||||||
Nd4jLong repeatDelta = target.sizeAt(dimension) / sizeAt(dimension);
|
|
||||||
|
|
||||||
if(repeatDelta == 0)
|
|
||||||
throw std::invalid_argument("NDArray::repeat(int dimension, NDArray& target) method: wrong shape of target array!");
|
|
||||||
|
|
||||||
|
|
||||||
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rankOf(), {dimension});
|
|
||||||
const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs(getShapeInfo(), dimsToExclude);
|
|
||||||
|
|
||||||
for (int i = 0; i < numTads; i++) {
|
|
||||||
auto thisTensor = (*this)(i, dimsToExclude);
|
|
||||||
auto retTensor = target(i, dimsToExclude);
|
|
||||||
int tensorLength = thisTensor.lengthOf();
|
|
||||||
int retIdx = 0;
|
|
||||||
if (isR()) {
|
|
||||||
for (int k = 0; k < tensorLength; k++) {
|
|
||||||
auto s = thisTensor.e<double>(k);
|
|
||||||
for (int j = 0; j < repeatDelta; j++) {
|
|
||||||
retTensor.p<double>(retIdx++, s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for (int k = 0; k < tensorLength; k++) {
|
|
||||||
auto s = thisTensor.e<Nd4jLong>(k);
|
|
||||||
for (int j = 0; j < repeatDelta; j++) {
|
|
||||||
retTensor.p<Nd4jLong>(retIdx++, s);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
|
|
@ -392,64 +392,115 @@ void NDArray::tile(NDArray& target) const {
|
||||||
registerSpecialUse({&target}, {this});
|
registerSpecialUse({&target}, {this});
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
// create new array by repeating it the number of times given by reps
|
template<typename X, typename Z>
|
||||||
NDArray* NDArray::repeat(int dimension, const std::vector<Nd4jLong>& repeats) const {
|
__global__ static void repeatCuda(const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
auto outShape = ShapeUtils::evalRepeatShape(dimension, repeats, *this);
|
void* vz, const Nd4jLong* zShapeInfo,
|
||||||
|
const int* repeats, const int repSize,
|
||||||
|
const int axis) {
|
||||||
|
|
||||||
// the size of outShape == rank
|
const X* x = reinterpret_cast<const X*>(vx);
|
||||||
int rank = rankOf(); // = outShape.size()
|
Z* z = reinterpret_cast<Z*>(vz);
|
||||||
|
|
||||||
std::vector<Nd4jLong> newShape(rank);
|
__shared__ int rank;
|
||||||
for (int i = 0; i < rank; i++)
|
__shared__ Nd4jLong zLen, totalThreads, *sharedMem; // xLen = zLen
|
||||||
newShape[i] = outShape[i];
|
|
||||||
|
|
||||||
auto ret = new NDArray('c', outShape, dataType(), getContext());
|
if (threadIdx.x == 0) {
|
||||||
|
|
||||||
auto repeatDelta = shape::prodLong(newShape.data(), rank) / this->lengthOf();
|
extern __shared__ unsigned char shmem[];
|
||||||
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rankOf(), {dimension});
|
sharedMem = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs(getShapeInfo(), dimsToExclude); //this->tensorsAlongDimension({dimension});
|
|
||||||
std::vector<int> copy({dimension});
|
|
||||||
|
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy);
|
rank = shape::rank(zShapeInfo); // xRank = zRank
|
||||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(ret->getShapeInfo(), copy);
|
zLen = shape::length(zShapeInfo); // xLen <= zLen
|
||||||
|
|
||||||
prepareSpecialUse({ret}, {this});
|
totalThreads = gridDim.x * blockDim.x;
|
||||||
auto stream = getContext()->getCudaStream();
|
}
|
||||||
BUILD_SINGLE_SELECTOR(dataType(), repeatKernelH, (getSpecialBuffer(), ret->getSpecialBuffer(), numTads, lengthOf(), ret->lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES);
|
|
||||||
registerSpecialUse({ret}, {this});
|
|
||||||
|
|
||||||
return ret;
|
__syncthreads();
|
||||||
|
|
||||||
|
auto coords = sharedMem + threadIdx.x * rank;
|
||||||
|
|
||||||
|
const auto tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
|
|
||||||
|
for (Nd4jLong i = tid; i < zLen; i += totalThreads) {
|
||||||
|
|
||||||
|
shape::index2coords(rank, zShapeInfo + 1, i, zLen, coords);
|
||||||
|
|
||||||
|
const auto zOffset = shape::getOffset(0, zShapeInfo + 1, zShapeInfo + rank + 1, coords, rank);
|
||||||
|
|
||||||
|
if(repSize > 1) {
|
||||||
|
for (uint j = 0; j < repSize; ++j) {
|
||||||
|
coords[axis] -= repeats[j];
|
||||||
|
if (coords[axis] < 0) {
|
||||||
|
coords[axis] = j;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else
|
||||||
|
coords[axis] /= repeats[0];
|
||||||
|
|
||||||
|
z[zOffset] = x[shape::getOffset(0, xShapeInfo + 1, xShapeInfo + rank + 1, coords, rank)];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// fill array by repeating it the number of times given by reps
|
template<typename X, typename Z>
|
||||||
void NDArray::repeat(int dimension, NDArray& target) const {
|
static void repeatCudaLauncher(const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream,
|
||||||
|
const void* vx, const Nd4jLong* xShapeInfo,
|
||||||
|
void* vz, const Nd4jLong* zShapeInfo,
|
||||||
|
const int* repeats, const int repSize,
|
||||||
|
const int axis) {
|
||||||
|
|
||||||
if(dimension < 0)
|
repeatCuda<X,Z><<<blocksPerGrid, threadsPerBlock, sharedMem, *stream>>>(vx, xShapeInfo, vz, zShapeInfo, repeats, repSize, axis);
|
||||||
dimension += rankOf();
|
|
||||||
|
|
||||||
if(rankOf() != target.rankOf())
|
|
||||||
throw std::invalid_argument("NDArray::repeat(int dimension, NDArray& target) method: wrong rank of target array it must be equal to this array rank!");
|
|
||||||
|
|
||||||
Nd4jLong repeatDelta = target.sizeAt(dimension) / sizeAt(dimension);
|
|
||||||
|
|
||||||
if(repeatDelta == 0)
|
|
||||||
throw std::invalid_argument("NDArray::repeat(int dimension, NDArray& target) method: wrong shape of target array!");
|
|
||||||
|
|
||||||
|
|
||||||
std::vector<int> dimsToExclude = ShapeUtils::evalDimsToExclude(rankOf(), {dimension});
|
|
||||||
const Nd4jLong numTads = ShapeUtils::getNumOfSubArrs(getShapeInfo(), dimsToExclude);
|
|
||||||
|
|
||||||
std::vector<int> copy({dimension});
|
|
||||||
auto packX = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(this->getShapeInfo(), copy);
|
|
||||||
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(target.getShapeInfo(), copy);
|
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&target}, {this});
|
|
||||||
auto stream = getContext()->getCudaStream();
|
|
||||||
BUILD_SINGLE_SELECTOR_TWICE(target.dataType(), repeatKernelHH, (getSpecialBuffer(), target.getSpecialBuffer(), numTads, lengthOf(), packX.platformShapeInfo(), packX.platformOffsets(), packZ.platformShapeInfo(), packZ.platformOffsets(), *stream), LIBND4J_TYPES);
|
|
||||||
NDArray::registerSpecialUse({&target}, {this});
|
|
||||||
}
|
}
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void repeatCudaLauncher, (const int blocksPerGrid, const int threadsPerBlock, const int sharedMem, const cudaStream_t *stream, const void *vx, const Nd4jLong *xShapeInfo, void *vz, const Nd4jLong *zShapeInfo, const int* repeats, const int repSize, const int axis), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
// create new array by repeating it the number of times given by repeats
|
||||||
|
NDArray* NDArray::repeat(const int axis, const std::vector<int>& repeats) const {
|
||||||
|
|
||||||
|
auto output = new NDArray('c', ShapeUtils::evalRepeatShape(axis, repeats, *this), dataType(), getContext());
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (output->lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
const int sharedMem = output->rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||||
|
|
||||||
|
PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector<int>& repeats)");
|
||||||
|
|
||||||
|
const int* reps = reinterpret_cast<int*>(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int)));
|
||||||
|
|
||||||
|
prepareSpecialUse({output}, {this});
|
||||||
|
BUILD_SINGLE_SELECTOR_TWICE(dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), output->specialBuffer(), output->specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES);
|
||||||
|
prepareSpecialUse({output}, {this});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
|
||||||
|
return output;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
// fill array by repeating it the number of times given by repeats
|
||||||
|
void NDArray::repeat(const int axis, const std::vector<int>& repeats, NDArray& target) const {
|
||||||
|
|
||||||
|
if(!target.isSameShape(ShapeUtils::evalRepeatShape(axis, repeats, *this)))
|
||||||
|
throw std::invalid_argument("NDArray::repeat(const int axis, const std::vector<int>& repeats, NDArray& target) method: wrong shape of target array!");
|
||||||
|
|
||||||
|
const int threadsPerBlock = MAX_NUM_THREADS / 2;
|
||||||
|
const int blocksPerGrid = (target.lengthOf() + threadsPerBlock - 1) / threadsPerBlock;
|
||||||
|
const int sharedMem = target.rankOf() * sizeof(Nd4jLong) * threadsPerBlock + 128;
|
||||||
|
|
||||||
|
PointersManager manager(getContext(), "NDArray::repeat(const int axis, const std::vector<int>& repeats)");
|
||||||
|
|
||||||
|
const int* reps = reinterpret_cast<int*>(manager.replicatePointer(repeats.data(), repeats.size() * sizeof(int)));
|
||||||
|
|
||||||
|
prepareSpecialUse({&target}, {this});
|
||||||
|
BUILD_DOUBLE_SELECTOR(dataType(), target.dataType(), repeatCudaLauncher, (blocksPerGrid, threadsPerBlock, sharedMem, getContext()->getCudaStream(), getSpecialBuffer(), getSpecialShapeInfo(), target.specialBuffer(), target.specialShapeInfo(), reps, repeats.size(), axis), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
|
prepareSpecialUse({&target}, {this});
|
||||||
|
|
||||||
|
manager.synchronize();
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void* NDArray::specialBuffer() {
|
void* NDArray::specialBuffer() {
|
||||||
|
|
|
@ -47,7 +47,7 @@ namespace nd4j {
|
||||||
static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace);
|
static Nd4jLong* evalReduceShapeInfoEmpty(const char order, std::vector<int>& dimensions, const Nd4jLong *shapeInfo, const nd4j::DataType dataType, const bool keepDims, nd4j::memory::Workspace* workspace);
|
||||||
|
|
||||||
// evaluate shape for array which is result of repeat operation applied to arr
|
// evaluate shape for array which is result of repeat operation applied to arr
|
||||||
static std::vector<Nd4jLong> evalRepeatShape(int dimension, const std::vector<Nd4jLong>& repeats, const NDArray& arr);
|
static std::vector<Nd4jLong> evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr);
|
||||||
|
|
||||||
// evaluate shapeInfo of permuted array
|
// evaluate shapeInfo of permuted array
|
||||||
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
|
static Nd4jLong* evalPermShapeInfo(const int* dimensions, const int rank, const NDArray& arr, nd4j::memory::Workspace* workspace);
|
||||||
|
|
|
@ -281,37 +281,21 @@ Nd4jLong* ShapeUtils::evalReduceShapeInfo(const char order, std::vector<int>& di
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// evaluate shape for array which is result of repeat operation applied to arr
|
// evaluate shape for array which is result of repeat operation applied to arr
|
||||||
std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int dimension, const std::vector<Nd4jLong>& repeats, const NDArray& arr) {
|
std::vector<Nd4jLong> ShapeUtils::evalRepeatShape(int axis, const std::vector<int>& repeats, const NDArray& arr) {
|
||||||
|
|
||||||
int rank = arr.rankOf();
|
if (axis < 0)
|
||||||
|
axis += arr.rankOf();
|
||||||
|
|
||||||
if (dimension < 0)
|
if(repeats.size() != 1 && repeats.size() != arr.sizeAt(axis))
|
||||||
dimension += rank;
|
throw std::invalid_argument("ShapeUtils::evalRepeatShape: size of repeats vector must be 1 or equal to dimension at given axis !");
|
||||||
|
|
||||||
std::vector<Nd4jLong> reps;
|
std::vector<Nd4jLong> outShape = arr.getShapeAsVector();
|
||||||
|
|
||||||
if ((int) reps.size() < rank) {
|
if(repeats.size() == 1)
|
||||||
if (dimension > 0) {
|
outShape[axis] *= repeats[0];
|
||||||
for (int e = 0; e < rank - (int) repeats.size(); e++)
|
|
||||||
reps.push_back(1);
|
|
||||||
|
|
||||||
for (auto r: repeats)
|
else
|
||||||
reps.push_back(r);
|
outShape[axis] = std::accumulate(repeats.begin(), repeats.end(), 0);
|
||||||
} else {
|
|
||||||
for (auto r: repeats)
|
|
||||||
reps.push_back(r);
|
|
||||||
|
|
||||||
for (int e = 0; e < rank - (int) repeats.size(); e++)
|
|
||||||
reps.push_back(1);
|
|
||||||
}
|
|
||||||
}/* else {
|
|
||||||
for (auto r: repeats)
|
|
||||||
reps.push_back(r);
|
|
||||||
}*/
|
|
||||||
|
|
||||||
std::vector<Nd4jLong> outShape(rank);
|
|
||||||
for (int i = 0; i < rank; i++)
|
|
||||||
outShape[i] = arr.sizeAt(i) * reps.at(i);
|
|
||||||
|
|
||||||
return outShape;
|
return outShape;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,97 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
//
|
|
||||||
// @author GS <sgazeos@gmail.com>, created on 17.01.2019
|
|
||||||
//
|
|
||||||
|
|
||||||
#include <loops/special_kernels.h>
|
|
||||||
|
|
||||||
namespace nd4j {
|
|
||||||
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
static __global__ void repeatKernel(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
|
|
||||||
Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
|
|
||||||
Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets) {
|
|
||||||
//auto tid = blockIdx.x * blockDim.x; // + threadIdx.x;
|
|
||||||
// int totalThreads = gridDim.x * blockDim.x;
|
|
||||||
int totalThreads = blockDim.x;
|
|
||||||
//const auto resultLength = shape::length(outputShape);
|
|
||||||
for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) {
|
|
||||||
auto yOffset = tadInputOffsets[i];
|
|
||||||
auto xOffset = tadOutputOffsets[i];
|
|
||||||
for (Nd4jLong j = threadIdx.x; j < inputLength; j += totalThreads) {
|
|
||||||
auto outputOffset = shape::getIndexOrderOffset(j, tadOnlyOutputShapeInfo, inputLength, shape::order(tadOnlyInputShapeInfo));
|
|
||||||
auto inputOffset = shape::getIndexOrderOffset(j, tadOnlyInputShapeInfo, inputLength, shape::order(tadOnlyInputShapeInfo));
|
|
||||||
*(reinterpret_cast<T*>(outputBuffer) + xOffset + outputOffset) = *(reinterpret_cast<T const*>(inputBuffer) + yOffset + inputOffset);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE(template __global__ void repeatKernel, (void const* inputBuffer, void* outputBuffer,
|
|
||||||
Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
|
|
||||||
Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets), LIBND4J_TYPES);
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
static __global__ void repeatKernelDouble(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
|
|
||||||
Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
|
|
||||||
Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets) {
|
|
||||||
//auto tid = blockIdx.x * blockDim.x; // + threadIdx.x;
|
|
||||||
int totalThreads = gridDim.x * blockDim.x;
|
|
||||||
//const auto resultLength = shape::length(outputShape);
|
|
||||||
for (Nd4jLong i = blockIdx.x; i < numTads; i += gridDim.x) {
|
|
||||||
auto yOffset = tadInputOffsets[i];
|
|
||||||
auto xOffset = tadOutputOffsets[i];
|
|
||||||
for (Nd4jLong j = threadIdx.x; j < inputLength; j += totalThreads) {
|
|
||||||
auto outputOffset = shape::getIndexOrderOffset(j, tadOnlyOutputShapeInfo, inputLength, shape::order(tadOnlyInputShapeInfo));
|
|
||||||
auto inputOffset = shape::getIndexOrderOffset(j, tadOnlyInputShapeInfo, inputLength, shape::order(tadOnlyInputShapeInfo));
|
|
||||||
*(reinterpret_cast<X*>(outputBuffer) + xOffset + outputOffset) = static_cast<X>(*(reinterpret_cast<Y const*>(inputBuffer) + yOffset + inputOffset));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE_TWICE(template __global__ void repeatKernelDouble, (void const* inputBuffer, void* outputBuffer,
|
|
||||||
Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
|
|
||||||
Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets), LIBND4J_TYPES);
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void repeatKernelH(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong outputLength,
|
|
||||||
Nd4jLong *tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
|
|
||||||
Nd4jLong *tadOnlyOutputShapeInfo,Nd4jLong *tadOutputOffsets,
|
|
||||||
cudaStream_t stream) {
|
|
||||||
dim3 launchDims(256, 512, 8192);
|
|
||||||
repeatKernel<T><<<launchDims.x, launchDims.y, launchDims.z, stream>>>(inputBuffer, outputBuffer, numTads, inputLength, tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets);
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void repeatKernelH, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong outputLength,
|
|
||||||
Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
|
|
||||||
Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets,
|
|
||||||
cudaStream_t stream), LIBND4J_TYPES);
|
|
||||||
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
void repeatKernelHH(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
|
|
||||||
Nd4jLong *tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
|
|
||||||
Nd4jLong *tadOnlyOutputShapeInfo,Nd4jLong *tadOutputOffsets,
|
|
||||||
cudaStream_t stream) {
|
|
||||||
dim3 launchDims(256, 512, 8192);
|
|
||||||
repeatKernelDouble<X,Y><<<launchDims.x, launchDims.y, launchDims.z, stream>>>(inputBuffer, outputBuffer, numTads, inputLength, tadOnlyInputShapeInfo, tadInputOffsets, tadOnlyOutputShapeInfo, tadOutputOffsets);
|
|
||||||
}
|
|
||||||
BUILD_SINGLE_TEMPLATE_TWICE(template void repeatKernelHH, (void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
|
|
||||||
Nd4jLong* tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
|
|
||||||
Nd4jLong* tadOnlyOutputShapeInfo, Nd4jLong *tadOutputOffsets,
|
|
||||||
cudaStream_t stream), LIBND4J_TYPES);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -89,17 +89,6 @@ namespace nd4j {
|
||||||
template <typename X, typename Y>
|
template <typename X, typename Y>
|
||||||
_CUDA_H void tileKernelHH(void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream);
|
_CUDA_H void tileKernelHH(void const* inputBuffer, Nd4jLong* inputShape, void* outputBuffer, Nd4jLong* outputShape, Nd4jLong resultLength, Nd4jLong ews, cudaStream_t *stream);
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
_CUDA_H void repeatKernelH(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength, Nd4jLong outputLength,
|
|
||||||
Nd4jLong *tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
|
|
||||||
Nd4jLong *tadOnlyOutputShapeInfo,Nd4jLong *tadOutputOffsets,
|
|
||||||
cudaStream_t stream);
|
|
||||||
|
|
||||||
template <typename X, typename Y>
|
|
||||||
_CUDA_H void repeatKernelHH(void const* inputBuffer, void* outputBuffer, Nd4jLong numTads, Nd4jLong inputLength,
|
|
||||||
Nd4jLong *tadOnlyInputShapeInfo, Nd4jLong *tadInputOffsets,
|
|
||||||
Nd4jLong *tadOnlyOutputShapeInfo,Nd4jLong *tadOutputOffsets,
|
|
||||||
cudaStream_t stream);
|
|
||||||
|
|
||||||
class NDArray;
|
class NDArray;
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
//
|
//
|
||||||
// @author raver119@gmail.com
|
// @author Yurii Shyrma (iuriish@yahoo.com)
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <op_boilerplate.h>
|
#include <op_boilerplate.h>
|
||||||
|
@ -25,14 +25,25 @@
|
||||||
|
|
||||||
namespace nd4j {
|
namespace nd4j {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////////
|
||||||
// here iArgs is int vector of repeats at the beginning and last element in iArgs is dimension
|
// here iArgs is int vector of repeats at the beginning and last element in iArgs is dimension
|
||||||
CUSTOM_OP_IMPL(repeat, 1, 1, true, 0, -1) {
|
CUSTOM_OP_IMPL(repeat, 1, 1, true, 0, -1) {
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto ret = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
x->repeat(block.getIArguments()->back(), *ret);
|
std::vector<int> repeats = *block.getIArguments();
|
||||||
|
|
||||||
|
const int axis = repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back();
|
||||||
|
|
||||||
|
repeats.pop_back();
|
||||||
|
|
||||||
|
REQUIRE_TRUE(0 <= axis && axis < input->rankOf(), 0, "CUSTOM REPEAT OP: wrong axis argument it should be less then input array rank %i, but got %i instead !", input->rankOf(), axis);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(repeats.size() == 1 || repeats.size() == input->sizeAt(axis), 0, "CUSTOM REPEAT OP: wrong axis argument, size of repeats vector must be 1 or equal to dimension at given axis, but got repeats.size = %i and axis = %i !", repeats.size(), axis);
|
||||||
|
|
||||||
|
input->repeat(axis, repeats, *output);
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
@ -45,15 +56,18 @@ namespace nd4j {
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(repeat) {
|
DECLARE_SHAPE_FN(repeat) {
|
||||||
|
|
||||||
auto x = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto argumets = block.getIArguments();
|
|
||||||
int argsSize = argumets->size();
|
std::vector<int> repeats = *block.getIArguments();
|
||||||
int dimension = (*argumets)[argsSize-1];
|
|
||||||
auto repeats = *argumets;
|
const int axis = repeats.back() < 0 ? repeats.back() + input->rankOf() : repeats.back();
|
||||||
|
|
||||||
repeats.pop_back();
|
repeats.pop_back();
|
||||||
|
|
||||||
auto outShape = ShapeUtils::evalRepeatShape(dimension, ArrayUtils::toLongVector(repeats), *x);
|
auto outShape = ShapeUtils::evalRepeatShape(axis, repeats, *input);
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(x->dataType(), x->ordering(), outShape)));
|
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->createShapeInfo(ShapeDescriptor(input->dataType(), input->ordering(), outShape)));
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1270,6 +1270,10 @@ static void tileBP_(const NDArray& gradO /*input*/, NDArray& gradI /*output*/, c
|
||||||
|
|
||||||
BUILD_SINGLE_TEMPLATE(template void tileBP_, (const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void tileBP_, (const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps), FLOAT_TYPES);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -73,7 +73,6 @@ namespace helpers {
|
||||||
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis);
|
void concat(nd4j::LaunchContext * context, const std::vector<NDArray*>& inArrs, NDArray& output, const int axis);
|
||||||
|
|
||||||
void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps);
|
void tileBP(nd4j::LaunchContext * context, const NDArray& gradO /*input*/, NDArray& gradI /*output*/, const std::vector<Nd4jLong> reps);
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1782,40 +1782,6 @@ TEST_F(DeclarableOpsTests1, Reshape7){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
|
||||||
TEST_F(DeclarableOpsTests1, Repeat1) {
|
|
||||||
|
|
||||||
float eBuffer[8] = {1.0,2.0,1.0,2.0,3.0,4.0,3.0,4.0};
|
|
||||||
Nd4jLong eShape[8] = {2, 4, 2, 2, 1, 0, 1, 99};
|
|
||||||
ArrayOptions::setDataType(eShape, nd4j::DataType::FLOAT32);
|
|
||||||
auto x = NDArrayFactory::create_<float>('c', {2, 2});
|
|
||||||
auto exp = new NDArray(eBuffer, eShape);
|
|
||||||
for (int e = 0; e < x->lengthOf(); e++)
|
|
||||||
x->p(e, e + 1);
|
|
||||||
|
|
||||||
auto variableSpace = new VariableSpace();
|
|
||||||
variableSpace->putVariable(-1, x);
|
|
||||||
variableSpace->putVariable(1, new Variable());
|
|
||||||
|
|
||||||
auto block = new Context(1, variableSpace, false);
|
|
||||||
block->fillInputs({-1});
|
|
||||||
std::vector<int>* arguments = block->getIArguments();
|
|
||||||
*arguments = {2}; // set repeats
|
|
||||||
arguments->push_back(0); // set dimension
|
|
||||||
|
|
||||||
nd4j::ops::repeat repeat;
|
|
||||||
|
|
||||||
Nd4jStatus status = repeat.execute(block);
|
|
||||||
ASSERT_EQ(ND4J_STATUS_OK, status);
|
|
||||||
auto result = variableSpace->getVariable(block->getNodeId())->getNDArray();
|
|
||||||
|
|
||||||
ASSERT_TRUE(exp->equalsTo(result));
|
|
||||||
|
|
||||||
delete exp;
|
|
||||||
delete block;
|
|
||||||
delete variableSpace;
|
|
||||||
}
|
|
||||||
|
|
||||||
//////////////////////////////////////////////////////////////////////
|
//////////////////////////////////////////////////////////////////////
|
||||||
TEST_F(DeclarableOpsTests1, Transpose1) {
|
TEST_F(DeclarableOpsTests1, Transpose1) {
|
||||||
|
|
||||||
|
|
|
@ -450,3 +450,93 @@ TEST_F(DeclarableOpsTests14, test_empty_tanh_5) {
|
||||||
|
|
||||||
delete result;
|
delete result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests14, repeat_1) {
|
||||||
|
|
||||||
|
NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6});
|
||||||
|
NDArray e('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});
|
||||||
|
|
||||||
|
nd4j::ops::repeat op;
|
||||||
|
auto result = op.execute({&x}, {}, {2, 0});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests14, repeat_2) {
|
||||||
|
|
||||||
|
NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6});
|
||||||
|
NDArray e('c', {2, 6}, {1, 1, 2, 2, 3, 3,4, 4, 5, 5, 6, 6});
|
||||||
|
|
||||||
|
nd4j::ops::repeat op;
|
||||||
|
auto result = op.execute({&x}, {}, {2, 1});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests14, repeat_3) {
|
||||||
|
|
||||||
|
NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6});
|
||||||
|
NDArray e('c', {2, 6}, {1, 2, 2, 3, 3, 3,4, 5, 5, 6, 6, 6});
|
||||||
|
|
||||||
|
nd4j::ops::repeat op;
|
||||||
|
auto result = op.execute({&x}, {}, {1,2,3, 1});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests14, repeat_4) {
|
||||||
|
|
||||||
|
NDArray x('c', {2, 3}, {1, 2, 3, 4, 5, 6});
|
||||||
|
NDArray e('c', {7, 3}, {1, 2, 3, 1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6, 4, 5, 6, 4, 5, 6});
|
||||||
|
|
||||||
|
nd4j::ops::repeat op;
|
||||||
|
auto result = op.execute({&x}, {}, {3,4, 0});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////
|
||||||
|
TEST_F(DeclarableOpsTests14, repeat_5) {
|
||||||
|
|
||||||
|
NDArray x('c', {2, 3, 4}, {1, 2, 3, 4, 5, 6, 7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24});
|
||||||
|
NDArray e('c', {2, 4, 4}, {1, 2, 3, 4, 5, 6, 7, 8, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 17, 18, 19, 20, 21, 22, 23, 24});
|
||||||
|
|
||||||
|
nd4j::ops::repeat op;
|
||||||
|
auto result = op.execute({&x}, {}, {1,2,1, 1});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_TRUE(e.isSameShape(z));
|
||||||
|
ASSERT_TRUE(e.equalsTo(z));
|
||||||
|
|
||||||
|
delete result;
|
||||||
|
}
|
||||||
|
|
|
@ -39,20 +39,6 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_repeat_119) {
|
|
||||||
auto x = NDArrayFactory::create<double>('c', {2, 3}, {1, 2, 3, 4, 5, 6});
|
|
||||||
auto e = NDArrayFactory::create<double>('c', {4, 3}, {1, 2, 3, 1, 2, 3, 4, 5, 6, 4, 5, 6});
|
|
||||||
|
|
||||||
nd4j::ops::repeat op;
|
|
||||||
auto result = op.execute({&x}, {}, {2, 0});
|
|
||||||
ASSERT_EQ(Status::OK(), result->status());
|
|
||||||
|
|
||||||
auto z = result->at(0);
|
|
||||||
ASSERT_EQ(e, *z);
|
|
||||||
|
|
||||||
delete result;
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests16, test_scatter_update_119) {
|
TEST_F(DeclarableOpsTests16, test_scatter_update_119) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
|
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
|
||||||
auto y = NDArrayFactory::create<int>(0);
|
auto y = NDArrayFactory::create<int>(0);
|
||||||
|
|
|
@ -196,7 +196,7 @@ TEST_F(MultiDataTypeTests, ndarray_repeat_test1) {
|
||||||
NDArray y('c', {2, 4}, nd4j::DataType::HALF);
|
NDArray y('c', {2, 4}, nd4j::DataType::HALF);
|
||||||
NDArray exp('c', {2, 4}, {0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5}, nd4j::DataType::HALF);
|
NDArray exp('c', {2, 4}, {0.5, 0.5, 1.5, 1.5, 2.5, 2.5, 3.5, 3.5}, nd4j::DataType::HALF);
|
||||||
|
|
||||||
x.repeat(1, y);
|
x.repeat(1, {2}, y);
|
||||||
|
|
||||||
ASSERT_EQ(y, exp);
|
ASSERT_EQ(y, exp);
|
||||||
}
|
}
|
||||||
|
|
|
@ -324,7 +324,7 @@ TEST_F(NDArrayTest, TestRepeat2) {
|
||||||
|
|
||||||
auto rep = exp->dup();
|
auto rep = exp->dup();
|
||||||
rep->assign(0.);
|
rep->assign(0.);
|
||||||
array->repeat(0, *rep);
|
array->repeat(0, {2}, *rep);
|
||||||
//rep->printIndexedBuffer("Repeated");
|
//rep->printIndexedBuffer("Repeated");
|
||||||
|
|
||||||
ASSERT_EQ(4, rep->sizeAt(0));
|
ASSERT_EQ(4, rep->sizeAt(0));
|
||||||
|
|
Loading…
Reference in New Issue