Null arrays fix (#208)
* don't skip null arrays

Signed-off-by: raver119 <raver119@gmail.com>

* one test tweak

Signed-off-by: raver119 <raver119@gmail.com>

parent 81efa5c3b6
commit 9bb5798cac
@@ -171,7 +171,7 @@ namespace nd4j {

            Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs);

-            template <class T>
+            template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
            Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<T> tArgs);

            Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
@@ -179,7 +179,7 @@ namespace nd4j {

            nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs);

-            template <class T>
+            template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
            nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);

            nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
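The two hunks above add a default template argument intended to restrict the templated execute/evaluate overloads to scalar types accepted by DataTypeUtils::scalarTypesForExecution. A minimal, self-contained sketch of the same SFINAE pattern, with a hypothetical is_allowed_scalar trait standing in for the real one:

    #include <cstdio>
    #include <initializer_list>
    #include <type_traits>

    // Hypothetical stand-in for DataTypeUtils::scalarTypesForExecution<T>:
    // true only for the scalar types an op may receive as extra arguments.
    template <typename T>
    struct is_allowed_scalar : std::integral_constant<bool,
            std::is_same<T, double>::value || std::is_same<T, float>::value ||
            std::is_same<T, int>::value> {};

    // The default template argument removes this overload from the candidate
    // set when the trait is false. Note the substitution failure only fires
    // when ::type is accessed, i.e. std::enable_if<...>::type, not enable_if alone.
    template <class T, typename = typename std::enable_if<is_allowed_scalar<T>::value>::type>
    void execute(std::initializer_list<T> tArgs) {
        for (auto v : tArgs)
            printf("%f\n", static_cast<double>(v));
    }

    int main() {
        execute({1.0, 2.0});      // compiles: double is an allowed scalar
        // execute({"a", "b"});   // would not compile: const char* is rejected
        return 0;
    }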
@@ -222,7 +222,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {

    auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)});   // [kW, iC, oC] -> [1, kW, iC, oC]

    nd4j::ops::conv2d_bp conv2dBP;
-    const Nd4jStatus status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
+    auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
    if (status != ND4J_STATUS_OK)
        return status;
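conv1d_bp does no convolution math of its own: it reshapes its 1D tensors to 2D (prepending a unit height, as the [kW, iC, oC] -> [1, kW, iC, oC] comment shows) and delegates to conv2d_bp with kH = 1, sH = 1, pH = 0, dH = 1. The reshape is cheap because adding a leading unit dimension to a C-ordered buffer moves no data; a small self-contained check of that claim:

    #include <cstdio>
    #include <vector>

    // A 1D conv weight [kW, iC, oC] can be viewed as the 2D weight
    // [1, kW, iC, oC] without copying: row-major offsets are identical.
    int main() {
        const int kW = 3, iC = 2, oC = 3;
        std::vector<float> w(kW * iC * oC);
        for (size_t i = 0; i < w.size(); i++) w[i] = (float) i;

        // offset of (k, i, o) in [kW, iC, oC] ...
        auto idx3 = [&](int k, int i, int o) { return (k * iC + i) * oC + o; };
        // ... equals offset of (0, k, i, o) in [1, kW, iC, oC]
        auto idx4 = [&](int n, int k, int i, int o) { return ((n * kW + k) * iC + i) * oC + o; };

        printf("same element: %f == %f\n", w[idx3(2, 1, 1)], w[idx4(0, 2, 1, 1)]);
        return 0;
    }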
@@ -165,10 +165,7 @@ namespace nd4j {
            // we build list of input shapes
            if (ctx.isFastPath()) {
                for (const auto p:ctx.fastpath_in()) {
-                    if (p == nullptr)
-                        continue;
-
-                    inSha.push_back(p->getShapeInfo());
+                    inSha.push_back(p == nullptr ? nullptr : p->getShapeInfo());
                }
            } else {
                for (auto p: *ctx.inputs()) {
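This hunk is the core of the fix named in the commit title: a null fast-path input is no longer skipped but pushed as a nullptr placeholder, so position i in inSha keeps corresponding to input i, which matters when an optional argument such as a bias sits between required ones. A minimal sketch of the alignment issue, with strings standing in for shape buffers:

    #include <cstdio>
    #include <vector>

    int main() {
        // Input 1 is an optional (absent) argument, e.g. a null bias.
        std::vector<const char*> inputs = {"shapeA", nullptr, "shapeC"};
        std::vector<const char*> inSha;

        // The old behaviour skipped nullptr, shifting "shapeC" into slot 1;
        // the new behaviour keeps a placeholder so slots stay aligned.
        for (const auto p : inputs)
            inSha.push_back(p == nullptr ? nullptr : p);

        for (size_t i = 0; i < inSha.size(); i++)
            printf("slot %zu -> %s\n", i, inSha[i] ? inSha[i] : "(null)");
        return 0;
    }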
@@ -184,6 +181,11 @@ namespace nd4j {
                }
            }

+            // if we override shape function, we'll return size of fastPath
+            if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
+                return (int) ctx.fastpath_out().size();
+            }
+
            // optionally saving input time
            if (Environment::getInstance()->isProfiling() && node != nullptr) {
                inputEnd = std::chrono::system_clock::now();
@@ -193,11 +195,6 @@ namespace nd4j {
                shapeStart = std::chrono::system_clock::now();
            }

-            // if we override shape function, we'll return size of fastPath
-            if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
-                return (int) ctx.fastpath_out().size();
-            }
-
            auto outSha = this->calculateOutputShape(&inSha, ctx);
            results = outSha->size();
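Taken together, these two hunks move the shapeFunctionOverride early return so that it now runs before calculateOutputShape. That ordering matters once inSha may contain nullptr placeholders: when the caller has pre-allocated the outputs, the shape function is skipped entirely instead of being handed null shapes. A schematic of the new control flow, with illustrative names rather than the real API:

    #include <cstdio>
    #include <vector>

    // Illustrative reconstruction of the reordered logic above.
    int prepareOutputs(bool fastPath, bool shapeOverride,
                       const std::vector<const void*> &inSha, int fastpathOutSize) {
        // moved up: if the caller supplied output arrays, return their count
        // before any shape calculation can trip over nullptr input shapes
        if (fastPath && shapeOverride)
            return fastpathOutSize;

        // otherwise the shape function would run here on inSha, which after
        // this patch may legitimately contain nullptr entries
        return (int) inSha.size();
    }

    int main() {
        std::vector<const void*> inSha = {nullptr};             // a null input shape
        printf("%d\n", prepareOutputs(true, true, inSha, 2));   // prints 2
        return 0;
    }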
@@ -870,16 +867,10 @@ namespace nd4j {
            Context ctx(1);

            for (int e = 0; e < inputs.size(); e++) {
-                if (inputs[e] == nullptr)
-                    break;
-
                ctx.setInputArray(e, inputs[e]);
            }

            for (int e = 0; e < outputs.size(); e++) {
-                if (outputs[e] == nullptr)
-                    break;
-
                ctx.setOutputArray(e, outputs[e]);
            }
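The same policy is applied when execute() builds its Context: a null array no longer terminates the loop, so arguments after an absent optional input (for example the null bias in the conv2d_bp call earlier) still reach the op. A toy Context illustrating the difference, with hypothetical names standing in for the real class:

    #include <cstdio>
    #include <vector>

    // Minimal stand-in for Context: stores every input at its original
    // index, including nullptr, instead of stopping at the first null.
    struct MiniContext {
        std::vector<const void*> in;
        void setInputArray(int e, const void* a) {
            if ((int) in.size() <= e) in.resize(e + 1, nullptr);
            in[e] = a;
        }
    };

    int main() {
        int a = 1, c = 3;
        std::vector<const void*> inputs = {&a, nullptr, &c};   // null in the middle

        MiniContext ctx;
        for (int e = 0; e < (int) inputs.size(); e++)
            ctx.setInputArray(e, inputs[e]);   // no break on nullptr

        printf("stored %zu inputs; slot 1 is %s\n",
               ctx.in.size(), ctx.in[1] ? "set" : "null");
        return 0;
    }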
@@ -37,4 +37,33 @@ public:
        printf("\n");
        fflush(stdout);
    }
};
+
+TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
+    /*
+    DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp")
+            .addInputs(
+                    Nd4j.create(DataType.FLOAT, 2,2,12),
+                    Nd4j.create(DataType.FLOAT, 3,2,3),
+                    Nd4j.create(DataType.FLOAT, 2,3,6)
+            )
+            .addOutputs(
+                    Nd4j.create(DataType.FLOAT, 2,2,12),
+                    Nd4j.create(DataType.FLOAT, 3,2,3))
+            .addIntegerArguments(3,2,0,1,2,0)
+            .build();
+
+    Nd4j.exec(op);
+    */
+
+    auto t = NDArrayFactory::create<float>('c', {2, 2, 12});
+    auto u = NDArrayFactory::create<float>('c', {3, 2, 3});
+    auto v = NDArrayFactory::create<float>('c', {2, 3, 6});
+
+    nd4j::ops::conv1d_bp op;
+    auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2, 0});
+    ASSERT_EQ(Status::OK(), result->status());
+
+
+    delete result;
+}
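The six integer arguments in evaluate() appear to line up with the scalar parameters conv1d_bp forwards to conv2d_bp in the earlier hunk, presumably kW = 3, sW = 2, pW = 0, dW = 1, paddingMode = 2, isNCW = 0.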
@@ -810,9 +810,11 @@ public class RandomTests extends BaseNd4jTest {
            threads[x].start();
        }

-        for (int x = 0; x < threads.length; x++) {
+        // we want all threads finished before comparing arrays
+        for (int x = 0; x < threads.length; x++)
            threads[x].join();

+        for (int x = 0; x < threads.length; x++) {
            assertNotEquals(null, list.get(x));

            if (x > 0) {
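This is the "one test tweak" from the commit message: the original loop joined thread x and immediately asserted on its result while later threads were still running; the fix joins every thread first, then compares. The same idea, sketched with std::thread rather than the actual Java test:

    #include <cstdio>
    #include <thread>
    #include <vector>

    // Join every worker in a first pass, then inspect results in a second
    // pass, so no result is read while its producer may still be running.
    int main() {
        std::vector<int> results(4, 0);
        std::vector<std::thread> threads;

        for (int x = 0; x < 4; x++)
            threads.emplace_back([&results, x] { results[x] = x + 1; });

        // we want all threads finished before comparing results
        for (auto &t : threads)
            t.join();

        for (int x = 0; x < 4; x++)
            printf("thread %d produced %d\n", x, results[x]);
        return 0;
    }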