Null arrays fix (#208)
* don't skip null arrays Signed-off-by: raver119 <raver119@gmail.com> * one test tweak Signed-off-by: raver119 <raver119@gmail.com>master
parent
81efa5c3b6
commit
9bb5798cac
|
@ -171,7 +171,7 @@ namespace nd4j {
|
|||
|
||||
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs);
|
||||
|
||||
template <class T>
|
||||
template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
|
||||
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, std::initializer_list<T> tArgs);
|
||||
|
||||
Nd4jStatus execute(const std::vector<NDArray*> &inputs, const std::vector<NDArray*> &outputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
|
||||
|
@ -179,7 +179,7 @@ namespace nd4j {
|
|||
|
||||
nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs);
|
||||
|
||||
template <class T>
|
||||
template <class T, typename = std::enable_if<DataTypeUtils::scalarTypesForExecution<T>::value>>
|
||||
nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, std::initializer_list<T> args);
|
||||
|
||||
nd4j::ResultSet* evaluate(const std::vector<NDArray*> &inputs, const std::vector<double> &tArgs, const std::vector<Nd4jLong> &iArgs, const std::vector<bool> &bArgs = std::vector<bool>(), const std::vector<nd4j::DataType> &dArgs = std::vector<nd4j::DataType>(), bool isInplace = false);
|
||||
|
|
|
@ -222,7 +222,7 @@ CUSTOM_OP_IMPL(conv1d_bp, 3, 2, false, 0, 5) {
|
|||
auto gradWReshaped = gradW ->reshape(gradW->ordering(), {1, weights->sizeAt(0), weights->sizeAt(1), weights->sizeAt(2)}); // [kW, iC, oC] -> [1, kW, iC, oC]
|
||||
|
||||
nd4j::ops::conv2d_bp conv2dBP;
|
||||
const Nd4jStatus status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
|
||||
auto status = conv2dBP.execute({&inputReshaped, &weightsReshaped, bias, &gradOReshaped}, {&gradIReshaped, &gradWReshaped, gradB}, {}, {1,kW, 1,sW, 0,pW, 1,dW, paddingMode, !isNCW}, {});
|
||||
if (status != ND4J_STATUS_OK)
|
||||
return status;
|
||||
|
||||
|
|
|
@ -165,10 +165,7 @@ namespace nd4j {
|
|||
// we build list of input shapes
|
||||
if (ctx.isFastPath()) {
|
||||
for (const auto p:ctx.fastpath_in()) {
|
||||
if (p == nullptr)
|
||||
continue;
|
||||
|
||||
inSha.push_back(p->getShapeInfo());
|
||||
inSha.push_back(p == nullptr ? nullptr : p->getShapeInfo());
|
||||
}
|
||||
} else {
|
||||
for (auto p: *ctx.inputs()) {
|
||||
|
@ -184,6 +181,11 @@ namespace nd4j {
|
|||
}
|
||||
}
|
||||
|
||||
// if we override shape function, we'll return size of fastPath
|
||||
if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
|
||||
return (int) ctx.fastpath_out().size();
|
||||
}
|
||||
|
||||
// optionally saving input time
|
||||
if (Environment::getInstance()->isProfiling() && node != nullptr) {
|
||||
inputEnd = std::chrono::system_clock::now();
|
||||
|
@ -193,11 +195,6 @@ namespace nd4j {
|
|||
shapeStart = std::chrono::system_clock::now();
|
||||
}
|
||||
|
||||
// if we override shape function, we'll return size of fastPath
|
||||
if (ctx.isFastPath() && ctx.shapeFunctionOverride()) {
|
||||
return (int) ctx.fastpath_out().size();
|
||||
}
|
||||
|
||||
auto outSha = this->calculateOutputShape(&inSha, ctx);
|
||||
results = outSha->size();
|
||||
|
||||
|
@ -870,16 +867,10 @@ namespace nd4j {
|
|||
Context ctx(1);
|
||||
|
||||
for (int e = 0; e < inputs.size(); e++) {
|
||||
if (inputs[e] == nullptr)
|
||||
break;
|
||||
|
||||
ctx.setInputArray(e, inputs[e]);
|
||||
}
|
||||
|
||||
for (int e = 0; e < outputs.size(); e++) {
|
||||
if (outputs[e] == nullptr)
|
||||
break;
|
||||
|
||||
ctx.setOutputArray(e, outputs[e]);
|
||||
}
|
||||
|
||||
|
|
|
@ -38,3 +38,32 @@ public:
|
|||
fflush(stdout);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_F(DeclarableOpsTests19, test_conv1d_bp_1) {
|
||||
/*
|
||||
DynamicCustomOp op = DynamicCustomOp.builder("conv1d_bp")
|
||||
.addInputs(
|
||||
Nd4j.create(DataType.FLOAT, 2,2,12),
|
||||
Nd4j.create(DataType.FLOAT, 3,2,3),
|
||||
Nd4j.create(DataType.FLOAT, 2,3,6)
|
||||
)
|
||||
.addOutputs(
|
||||
Nd4j.create(DataType.FLOAT, 2,2,12),
|
||||
Nd4j.create(DataType.FLOAT, 3,2,3))
|
||||
.addIntegerArguments(3,2,0,1,2,0)
|
||||
.build();
|
||||
|
||||
Nd4j.exec(op);
|
||||
*/
|
||||
|
||||
auto t = NDArrayFactory::create<float>('c', {2, 2, 12});
|
||||
auto u = NDArrayFactory::create<float>('c', {3, 2, 3});
|
||||
auto v = NDArrayFactory::create<float>('c', {2, 3, 6});
|
||||
|
||||
nd4j::ops::conv1d_bp op;
|
||||
auto result = op.evaluate({&t, &u, &v}, {3, 2, 0, 1, 2,0});
|
||||
ASSERT_EQ(Status::OK(), result->status());
|
||||
|
||||
|
||||
delete result;
|
||||
}
|
|
@ -810,9 +810,11 @@ public class RandomTests extends BaseNd4jTest {
|
|||
threads[x].start();
|
||||
}
|
||||
|
||||
for (int x = 0; x < threads.length; x++) {
|
||||
// we want all threads finished before comparing arrays
|
||||
for (int x = 0; x < threads.length; x++)
|
||||
threads[x].join();
|
||||
|
||||
for (int x = 0; x < threads.length; x++) {
|
||||
assertNotEquals(null, list.get(x));
|
||||
|
||||
if (x > 0) {
|
||||
|
|
Loading…
Reference in New Issue