[WIP] Minor fixes (#140)

* - Tile java shape fn removed
- Tile 0 validation added
- scatter_upd test

Signed-off-by: raver119 <raver119@gmail.com>

* additional tile validation

Signed-off-by: raver119 <raver119@gmail.com>

* - provide vector case in cuda scatter op

Signed-off-by: Yurii <yurii@skymind.io>

* cpu ismax view fix

Signed-off-by: raver119 <raver119@gmail.com>

* exp

Signed-off-by: raver119 <raver119@gmail.com>

* cuda ismax fix

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-21 15:05:47 +03:00 committed by GitHub
parent a5867bb527
commit d9ab299759
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 89 additions and 45 deletions

View File

@ -231,7 +231,8 @@ void* NDArray::getSpecialBuffer() const {
// change an array by repeating it the number of times given by reps. // change an array by repeating it the number of times given by reps.
NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const { NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
const int repsSize = reps.size(); const int repsSize = reps.size();
int product = 1;
Nd4jLong product = 1;
for(const auto& item : reps) for(const auto& item : reps)
product *= item; product *= item;
if(product == 0) if(product == 0)
@ -286,6 +287,10 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
// change an array by repeating it the number of times given by reps. // change an array by repeating it the number of times given by reps.
void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const { void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
auto repProd = shape::prodLong(reps.data(), reps.size());
if (repProd < 1)
throw std::runtime_error("NDArray::tile: reps can't contain 0s");
// evaluate true tile shapeInfo for comparison with target shapeInfo // evaluate true tile shapeInfo for comparison with target shapeInfo
auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace());
if(!shape::equalsSoft(newShapeInfo, target.getShapeInfo())) { if(!shape::equalsSoft(newShapeInfo, target.getShapeInfo())) {

View File

@ -312,7 +312,8 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
Nd4jLong product = 1; Nd4jLong product = 1;
for(const auto& item : reps) for(const auto& item : reps)
product *= item; product *= item;
if(product == 0)
if(product < 1)
throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !"); throw std::runtime_error("NDArray::tile method: one of the elements in reps array is zero !");
int rankOld = rankOf(); int rankOld = rankOf();
@ -351,6 +352,10 @@ NDArray NDArray::tile(const std::vector<Nd4jLong>& reps) const {
// change an array by repeating it the number of times given by reps. // change an array by repeating it the number of times given by reps.
void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const { void NDArray::tile(const std::vector<Nd4jLong>& reps, NDArray& target) const {
auto repProd = shape::prodLong(reps.data(), reps.size());
if (repProd < 1)
throw std::runtime_error("NDArray::tile: reps can't contain 0s");
// evaluate true tile shapeInfo for comparison with target shapeInfo // evaluate true tile shapeInfo for comparison with target shapeInfo
auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace()); auto newShapeInfo = ShapeUtils::evalTileShapeInfo(*this, reps, getContext()->getWorkspace());
if(!shape::equalsSoft(newShapeInfo, target.getShapeInfo())) { if(!shape::equalsSoft(newShapeInfo, target.getShapeInfo())) {

View File

@ -48,18 +48,16 @@ namespace nd4j {
for (int r = blockIdx.x; r < numTads; r += gridDim.x) { for (int r = blockIdx.x; r < numTads; r += gridDim.x) {
auto tadOffsetForBlock = tadOffsets[r]; auto tadOffsetForBlock = tadOffsets[r];
auto highestElement = dX[r];
int highestElement = (int) dX[r];
if (dimensionLength > 1 || tadEWS < 1) { if (dimensionLength > 1 || tadEWS < 1) {
for (int e = threadIdx.x; e < tadLength; e += blockDim.x) { for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) {
auto xOffset = tadOffsetForBlock + shape::getIndexOffset(e, tadOnlyShapeInfo, tadLength); auto xOffset = tadOffsetForBlock + shape::getIndexOffset(e, tadOnlyShapeInfo, tadLength);
dZ[xOffset] = (e == highestElement ? (T) 1 : (T) 0); dZ[xOffset] = (e == highestElement ? (T) 1 : (T) 0);
} }
} else { } else {
for (int e = threadIdx.x; e < tadLength; e += blockDim.x) { for (Nd4jLong e = threadIdx.x; e < tadLength; e += blockDim.x) {
// so, we just set dZ[e] for each TAD. Sure, e should be replaced with // so, we just set dZ[e] for each TAD. Sure, e should be replaced with
auto idx = tadOffsetForBlock + (e * tadEWS); auto idx = tadOffsetForBlock + (e * tadEWS);
dZ[idx] = (e == highestElement ? (T) 1 : (T) 0); dZ[idx] = (e == highestElement ? (T) 1 : (T) 0);

View File

@ -50,6 +50,9 @@ CUSTOM_OP_IMPL(tile, 1, 1, false, 0, -2) {
else { else {
REQUIRE_TRUE(false, 0, "TILE op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !"); REQUIRE_TRUE(false, 0, "TILE op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !");
} }
auto repProd = shape::prodLong(reps.data(), reps.size());
REQUIRE_TRUE(repProd > 0, 0, "TILE op: reps can't contain 0s");
input->tile(reps, *output); input->tile(reps, *output);
@ -81,7 +84,10 @@ DECLARE_SHAPE_FN(tile) {
} }
else { else {
REQUIRE_TRUE(false, 0, "TILE op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !"); REQUIRE_TRUE(false, 0, "TILE op: this op requires repeats vector, either as IArgs or second array with length equal to rank of input array to be tiled !");
} }
auto repProd = shape::prodLong(reps.data(), reps.size());
REQUIRE_TRUE(repProd > 0, 0, "TILE op: reps can't contain 0s");
std::vector<Nd4jLong> shape(inRank); std::vector<Nd4jLong> shape(inRank);
for (int e = 0; e < shape::rank(inShape); e++) for (int e = 0; e < shape::rank(inShape); e++)

View File

@ -125,9 +125,12 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
//to the back. //to the back.
//permuted version of the input shape info for setting up the tad problem //permuted version of the input shape info for setting up the tad problem
auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), const_cast<int*>(dimensions.data()), dimensionsLength); auto tadPack = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), const_cast<int*>(dimensions.data()), dimensionsLength);
auto tadPackZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), const_cast<int*>(dimensions.data()), dimensionsLength);
auto tadShapeShapeInfo = tadPack.primaryShapeInfo(); auto tadShapeShapeInfo = tadPack.primaryShapeInfo();
auto tadOffsets = tadPack.primaryOffsets(); auto tadOffsets = tadPack.primaryOffsets();
auto zOfsets = tadPackZ.platformOffsets();
int tadLength = shape::length(tadShapeShapeInfo); int tadLength = shape::length(tadShapeShapeInfo);
int tads = tadPack.numberOfTads(); int tads = tadPack.numberOfTads();
@ -137,7 +140,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
num_threads = nd4j::math::nd4j_min<int>(num_threads, omp_get_max_threads()); num_threads = nd4j::math::nd4j_min<int>(num_threads, omp_get_max_threads());
auto tadEWS = shape::elementWiseStride(tadShapeShapeInfo); auto tadEWS = shape::elementWiseStride(tadShapeShapeInfo);
auto zEWS = tadEWS; auto zEWS = shape::elementWiseStride(tadPackZ.primaryShapeInfo());
int span = (tads / num_threads) + 8; int span = (tads / num_threads) + 8;
@ -151,7 +154,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
for (int r = start; r < end; r++) { for (int r = start; r < end; r++) {
if (tadEWS > 0 && zEWS > 0 && dimensionsLength == 1) { if (tadEWS > 0 && zEWS > 0 && dimensionsLength == 1) {
auto rX = const_cast<NDArray*>(input)->bufferAsT<X>() + tadOffsets[r]; auto rX = const_cast<NDArray*>(input)->bufferAsT<X>() + tadOffsets[r];
auto rZ = output->bufferAsT<Z>() + tadOffsets[r]; auto rZ = output->bufferAsT<Z>() + zOfsets[r];
auto maxValue = rX[0]; auto maxValue = rX[0];
int maxIdx = 0; int maxIdx = 0;
@ -168,7 +171,7 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
rZ[i] = maxIdx == i ? (Z) 1 : (Z) 0; rZ[i] = maxIdx == i ? (Z) 1 : (Z) 0;
} }
} }
else { else if (tadEWS > 1 && zEWS > 1) {
for (int i = 0; i < tadLength; i++) { for (int i = 0; i < tadLength; i++) {
if (rX[i * tadEWS] > maxValue) { if (rX[i * tadEWS] > maxValue) {
maxIdx = i; maxIdx = i;
@ -180,6 +183,20 @@ static void ismax_(const NDArray* input, NDArray* output, const std::vector<int>
for (int i = 0; i < tadLength; i++) { for (int i = 0; i < tadLength; i++) {
rZ[i * zEWS] = maxIdx == i ? (Z) 1 : (Z) 0; rZ[i * zEWS] = maxIdx == i ? (Z) 1 : (Z) 0;
} }
} else {
for (int i = 0; i < tadLength; i++) {
auto xOffset = shape::getIndexOffset(i, tadShapeShapeInfo, tadLength);
if (rX[xOffset] > maxValue) {
maxIdx = i;
maxValue = rX[xOffset];
}
}
PRAGMA_OMP_SIMD
for (int i = 0; i < tadLength; i++) {
auto zOffset = shape::getIndexOffset(i, tadPackZ.primaryShapeInfo(), tadLength);
rZ[zOffset] = maxIdx == i ? (Z) 1 : (Z) 0;
}
} }
} }
else { else {

View File

@ -62,7 +62,7 @@ static void ismax_(nd4j::LaunchContext * context, const NDArray* input, NDArray*
int dimensionLength = dimensions.size(); int dimensionLength = dimensions.size();
std::vector<int> copy(dimensions); std::vector<int> copy(dimensions);
auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(input->getShapeInfo(), copy.data(), copy.size()); auto packZ = nd4j::ConstantTadHelper::getInstance()->tadForDimensions(output->getShapeInfo(), copy.data(), copy.size());
auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions); auto indexMaxArr = input->applyIndexReduce(indexreduce::IndexMax, dimensions);

View File

@ -183,7 +183,7 @@ __global__ static void scatterLockCuda(const int opCode,
__shared__ bool vectorCase; __shared__ bool vectorCase;
if(threadIdx.x == 0) if(threadIdx.x == 0)
vectorCase = yTadLen == xLen && shape::rank(xShapeInfo) == 1; vectorCase = yTadLen == xLen && shape::rank(xShapeInfo) <= 1;
__syncthreads(); __syncthreads();
for (int e = 0; e < xLen; e++) { for (int e = 0; e < xLen; e++) {

View File

@ -52,3 +52,20 @@ TEST_F(DeclarableOpsTests16, test_repeat_119) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests16, test_scatter_update_119) {
auto x = NDArrayFactory::create<float>('c', {3}, {1, 1, 1});
auto y = NDArrayFactory::create<int>(0);
auto w = NDArrayFactory::create<float>(3.0f);
auto e = NDArrayFactory::create<float>('c', {3}, {3.f, 1.f, 1.f});
nd4j::ops::scatter_upd op;
auto result = op.execute({&x, &y, &w}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
}

View File

@ -1161,6 +1161,34 @@ TEST_F(JavaInteropTests, test_bfloat16_rng) {
ASSERT_TRUE(z.sumNumber().e<float>(0) > 0); ASSERT_TRUE(z.sumNumber().e<float>(0) > 0);
} }
TEST_F(JavaInteropTests, test_ismax_view) {
auto original = NDArrayFactory::create<double>('c', {2, 3, 40});
auto v = original.subarray({NDIndex::all(), NDIndex::all(), NDIndex::interval(0, 40, 2)});
v->assign(1.0);
auto e = v->ulike();
auto t = e.tensorAlongDimension(0, {0, 1});
t->assign(1.0);
auto z = v->ulike();
Nd4jLong iArgs[] = {2L, 0L};
Context ctx(1);
ctx.setInputArray(0, v->buffer(), v->shapeInfo(), v->specialBuffer(), v->specialShapeInfo());
ctx.setOutputArray(0, z.buffer(), z.shapeInfo(), z.specialBuffer(), z.specialShapeInfo());
ctx.setIArguments(iArgs, 1);
nd4j::ops::ismax op;
op.execute(&ctx);
z.printIndexedBuffer("z");
ASSERT_EQ(e, z);
delete v;
delete t;
}
/* /*
TEST_F(JavaInteropTests, Test_Results_Conversion_1) { TEST_F(JavaInteropTests, Test_Results_Conversion_1) {
auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb"); auto pl = nd4j::graph::readFlatBuffers("./resources/gru_dynamic_mnist.fb");

View File

@ -103,38 +103,6 @@ public class Tile extends DynamicCustomOp {
return ret; return ret;
} }
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
if(inputArguments.size() == 0)
return Collections.emptyList();
/**
* This op is special case: we can't infer its shape before both inputs are available.
* So if reps argument is full of 0.0s - we skip shape inference
*
* And during actual op invocation both inputs should be available due to topo sort
*/
if (is_static_reps)
return Nd4j.getExecutioner().calculateOutputShape(this);
if (inputArguments().length < 2)
return Collections.emptyList();
val array = inputArguments()[1];
// FIXME: int cast
val reps = new long[(int) array.length()];
for (int e = 0; e < reps.length; e++)
reps[e] = (int) array.getDouble(e);
if (ArrayUtil.prodLong(reps) == 0)
return Collections.emptyList();
else
return Nd4j.getExecutioner().calculateOutputShape(this);
}
@Override @Override
public String opName() { public String opName() {
return "tile"; return "tile";