Null arrays fix (#208)

* don't skip null arrays

Signed-off-by: raver119 <raver119@gmail.com>

* one test tweak

Signed-off-by: raver119 <raver119@gmail.com>
Branch: master
Author: raver119 <raver119@gmail.com>, 2020-02-02 23:14:00 +03:00, committed by GitHub
Parent: 81efa5c3b6
Commit: 9bb5798cac
5 changed files with 42 additions and 20 deletions

File 1 of 5

@@ -171,7 +171,7 @@ namespace nd4j {
             Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs);
 
-            template <class T>
+            template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
             Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<T> tArgs);
 
             Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
 
@@ -179,7 +179,7 @@ namespace nd4j {
             nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs);
 
-            template <class T>
+            template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
             nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
 
             nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
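
Note on the two header hunks above: the added default template argument is meant to keep the initializer_list overloads of execute()/evaluate() out of overload resolution unless T is a scalar execution type. One nit worth recording: std::enable_if<B> without ::type names a well-formed type for any B, so the constraint as committed never actually fails substitution; the conventional SFINAE spelling needs ::type (or std::enable_if_t). A self-contained sketch of that conventional form, using an assumed stand-in trait in place of DataTypeUtils::scalarTypesForExecution:

#include <type_traits>
#include <initializer_list>
#include <vector>
#include <cstdio>

// Assumed stand-in for DataTypeUtils::scalarTypesForExecution<T>:
// here, any arithmetic type qualifies as a scalar op argument.
template <typename T>
struct scalar_for_execution : std::is_arithmetic<T> {};

// With ::type, substitution fails for non-qualifying T and this
// overload silently drops out of overload resolution.
template <class T, typename = typename std::enable_if<scalar_for_execution<T>::value>::type>
void execute(std::initializer_list<T> tArgs) {
    printf("scalar overload, %zu args\n", tArgs.size());
}

void execute(const std::vector<double> &tArgs) {
    printf("vector overload, %zu args\n", tArgs.size());
}

int main() {
    execute({1.0f, 2.0f});                    // initializer_list<float> wins
    execute(std::vector<double>{3.0, 4.0});   // explicit vector overload
    return 0;
}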

File 2 of 5

@@ -222,7 +222,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
     auto gradWReshaped = gradW->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});    // [kW, iC, oC] -> [1, kW, iC, oC]
 
     nd4j::ops::conv2d_bp conv2dBP;
-    const Nd4jStatus status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
+    auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
 
     if (status != ND4J_STATUS_OK)
         return status;
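
For context on this call site (a reading of the surrounding op, not stated in the diff): conv1d_bp presumably treats bias as optional, so bias and gradB can legitimately be nullptr in the vectors handed to conv2d_bp. That is exactly the situation the rest of this commit has to survive: a null entry must keep its slot, otherwise gradOReshaped would slide into the bias position. A toy, self-contained illustration of the indexing hazard, with made-up names:

#include <vector>
#include <cstdio>

// Why a null entry must keep its slot: positions are the contract.
int main() {
    // conv-style argument list; the optional bias is absent (nullptr)
    std::vector<const char*> inputs = {"input", nullptr /* bias */, "gradO"};

    // old behaviour: skip nulls -> "gradO" slides into the bias slot
    std::vector<const char*> skewed;
    for (auto p : inputs) {
        if (p == nullptr)
            continue;
        skewed.push_back(p);
    }
    printf("skipping nulls: slot 1 holds %s\n", skewed[1]);

    // fixed behaviour: keep the placeholder so slot 2 is still "gradO"
    std::vector<const char*> aligned;
    for (auto p : inputs)
        aligned.push_back(p == nullptr ? nullptr : p);
    printf("keeping nulls:  slot 2 holds %s\n", aligned[2]);
    return 0;
}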

File 3 of 5

@@ -165,10 +165,7 @@ namespace nd4j {
         // we build list of input shapes
         if (ctx.isFastPath()) {
             for (const auto p:ctx.fastpath_in()) {
-                if (p == nullptr)
-                    continue;
-
-                inSha.push_back(p->getShapeInfo());
+                inSha.push_back(p == nullptr ? nullptr : p->getShapeInfo());
             }
         } else {
             for (auto p: *ctx.inputs()) {
@@ -184,6 +181,11 @@ namespace nd4j {
             }
         }
 
+        // if we override shape function, we'll return size of fastPath
+        if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
+            return (int) ctx.fastpath_out().size();
+        }
+
         // optionally saving input time
         if (Environment::getInstance()->isProfiling() && node != nullptr) {
             inputEnd = std::chrono::system_clock::now();

@@ -193,11 +195,6 @@ namespace nd4j {
             shapeStart = std::chrono::system_clock::now();
         }
 
-        // if we override shape function, we'll return size of fastPath
-        if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
-            return (int) ctx.fastpath_out().size();
-        }
-
         auto outSha = this->calculateOutputShape(&inSha, ctx);
         results = outSha->size();
@@ -870,16 +867,10 @@ namespace nd4j {
         Context ctx(1);
 
         for (int e = 0; e < inputs.size(); e++) {
-            if (inputs[e] == nullptr)
-                break;
-
             ctx.setInputArray(e, inputs[e]);
         }
 
         for (int e = 0; e < outputs.size(); e++) {
-            if (outputs[e] == nullptr)
-                break;
-
             ctx.setOutputArray(e, outputs[e]);
         }
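
Taken together, the three hunks in this file make null arrays first-class in the fast path: the shape list keeps a nullptr placeholder instead of compacting (preserving input indices), the shapeFunctionOverride early-return moves ahead of the profiling timers so it fires unconditionally, and the break-on-null loops no longer stop wiring arrays at the first empty slot (breaking would silently drop every array after an absent optional one). A minimal sketch of the resulting shape-path flow, with illustrative names rather than the real nd4j API:

#include <vector>
#include <cstdio>

// Toy model of the patched shape path (names are illustrative).
struct Array { std::vector<int> shape; };

int computeShapes(const std::vector<Array*> &fastpathIn,
                  bool shapeFunctionOverride, size_t fastpathOutSize) {
    std::vector<const std::vector<int>*> inSha;

    // keep a nullptr placeholder so input indices stay stable
    for (const auto p : fastpathIn)
        inSha.push_back(p == nullptr ? nullptr : &p->shape);

    // with an overridden shape function, return before any shape math
    // (and, after this patch, before the profiling timers would run)
    if (shapeFunctionOverride)
        return (int) fastpathOutSize;

    return (int) inSha.size();  // stand-in for calculateOutputShape()
}

int main() {
    Array input{{2, 2, 12}}, gradO{{2, 3, 6}};
    std::vector<Array*> in = {&input, nullptr /* absent bias */, &gradO};
    printf("%d\n", computeShapes(in, false, 0));  // 3: the null slot is kept
    printf("%d\n", computeShapes(in, true, 2));   // 2: fastpath_out().size()
    return 0;
}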

File 4 of 5

@@ -38,3 +38,32 @@ public:
         fflush(stdout);
     }
 };
+
+TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
+    /*
+    DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp")
+            .addInputs(
+                    Nd4j.create(DataType.FLOAT, 2,2,12),
+                    Nd4j.create(DataType.FLOAT, 3,2,3),
+                    Nd4j.create(DataType.FLOAT, 2,3,6)
+            )
+            .addOutputs(
+                    Nd4j.create(DataType.FLOAT, 2,2,12),
+                    Nd4j.create(DataType.FLOAT, 3,2,3))
+            .addIntegerArguments(3,2,0,1,2,0)
+            .build();
+
+    Nd4j.exec(op);
+    */
+
+    auto t = NDArrayFactory::create<float>('c', {2, 2, 12});
+    auto u = NDArrayFactory::create<float>('c', {3, 2, 3});
+    auto v = NDArrayFactory::create<float>('c', {2, 3, 6});
+
+    nd4j::ops::conv1d_bp op;
+    auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2, 0});
+    ASSERT_EQ(Status::OK(), result->status());
+
+    delete result;
+}

File 5 of 5

@@ -810,9 +810,11 @@ public class RandomTests extends BaseNd4jTest {
             threads[x].start();
         }
 
-        for (int x = 0; x < threads.length; x++) {
+        // we want all threads finished before comparing arrays
+        for (int x = 0; x < threads.length; x++)
             threads[x].join();
+        for (int x = 0; x < threads.length; x++) {
             assertNotEquals(null, list.get(x));
             if (x > 0) {
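
The Java tweak makes the ordering explicit: every worker is joined before any result is compared, so the assertions can no longer race the still-running threads. The same pattern rendered in C++ for consistency with the other sketches here (illustrative, not part of the commit):

#include <thread>
#include <vector>
#include <atomic>
#include <cassert>

// Join all workers first, only then inspect shared results;
// asserting while threads may still be writing is a race.
int main() {
    std::atomic<int> done{0};
    std::vector<std::thread> threads;
    for (int x = 0; x < 4; x++)
        threads.emplace_back([&done] { done++; });

    // we want all threads finished before comparing results
    for (auto &t : threads)
        t.join();

    assert(done == 4);  // safe: no worker is running anymore
    return 0;
}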