Merge pull request #6 from KonduitAI/shyrma_broadcast

- replace condition isScalar() by condition length ==1 in some NDArra…
raver119 2019-10-22 07:56:04 +03:00 committed by GitHub
commit 35e6ffede4
1 changed file with 20 additions and 20 deletions
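The change itself is mechanical: every check that demanded a strict scalar operand or target now only demands a single-element array. The two conditions differ for single-element arrays of rank greater than zero, which is what this relaxation is about. A minimal sketch of the distinction, assuming the usual NDArrayFactory::create overloads (the shapes and values are illustrative and not part of this diff):

    // Sketch assumes: #include <NDArrayFactory.h> and using namespace nd4j;
    auto s  = NDArrayFactory::create<float>(3.0f);                    // rank 0: isScalar() is true, lengthOf() == 1
    auto s3 = NDArrayFactory::create<float>('c', {1, 1, 1}, {3.0f});  // rank 3: lengthOf() == 1, but not a strict scalar
    // Before this commit, s3 was rejected by methods such as reduceNumber, the
    // compound-assignment operators, applyTrueBroadcast, applyScalarArr and p;
    // with lengthOf() as the test, both s and s3 are treated as scalar operands.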


@@ -1004,7 +1004,7 @@ NDArray NDArray::reduceNumber(nd4j::reduce::LongOps op, void *extraParams) const
 void NDArray::reduceNumber(nd4j::reduce::FloatOps op, NDArray& target, void *extraParams) const {
     if (isS())
         throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!");
-    if(!target.isScalar() || target.dataType() != DataTypeUtils::pickFloatingType(dataType()))
+    if(target.lengthOf() != 1 || target.dataType() != DataTypeUtils::pickFloatingType(dataType()))
         throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!");
     NDArray::prepareSpecialUse({&target}, {this});
@@ -1017,7 +1017,7 @@ void NDArray::reduceNumber(nd4j::reduce::SameOps op, NDArray& target, void *extr
     if (isS())
         throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!");
-    if(!target.isScalar() || target.dataType() != dataType())
+    if(target.lengthOf() != 1 || target.dataType() != dataType())
         throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!");
     NDArray::prepareSpecialUse({&target}, {this});
@@ -1030,7 +1030,7 @@ void NDArray::reduceNumber(nd4j::reduce::BoolOps op, NDArray& target, void *extr
     if (isS())
         throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!");
-    if(!target.isScalar() || target.dataType() != DataType::BOOL)
+    if(target.lengthOf() != 1 || target.dataType() != DataType::BOOL)
         throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!");
     NDArray::prepareSpecialUse({&target}, {this});
@@ -1043,7 +1043,7 @@ void NDArray::reduceNumber(nd4j::reduce::LongOps op, NDArray& target, void *extr
     if (isS())
         throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!");
-    if(!target.isScalar() || target.dataType() != DataType::INT64)
+    if(target.lengthOf() != 1 || target.dataType() != DataType::INT64)
         throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!");
     NDArray::prepareSpecialUse({&target}, {this});
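As a usage sketch of the relaxed reduceNumber checks: a single-element target of any rank is now accepted, as long as its data type still matches the overload. The factory calls, shapes, and the explicit nullptr for extraParams below are illustrative assumptions using the same setup as the first sketch:

    auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto t = NDArrayFactory::create<float>('c', {1, 1}, {0.f});   // length 1, rank 2
    x.reduceNumber(nd4j::reduce::Sum, t, nullptr);                // SameOps overload; t should end up holding 21.f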
@@ -2104,7 +2104,7 @@ void NDArray::operator+=(const NDArray& other) {
     if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
         throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType());
-    if (!this->isScalar() && other.isScalar()) {
+    if (this->lengthOf() != 1 && other.lengthOf() == 1) {
         NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
         NDArray::registerSpecialUse({this}, {this, &other});
@@ -2138,7 +2138,7 @@ void NDArray::operator-=(const NDArray& other) {
     if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
         throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType());
-    if (!this->isScalar() && other.isScalar()) {
+    if (lengthOf() != 1 && other.lengthOf() == 1) {
         NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
         NDArray::registerSpecialUse({this}, {this, &other});
@@ -2171,7 +2171,7 @@ void NDArray::operator*=(const NDArray& other) {
     if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
         throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType());
-    if (!this->isScalar() && other.isScalar()) {
+    if (lengthOf() != 1 && other.lengthOf() == 1) {
         NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
         NDArray::registerSpecialUse({this}, {this, &other});
@@ -2208,7 +2208,7 @@ void NDArray::operator/=(const NDArray& other) {
         throw nd4j::datatype_exception::build("NDArray operator/=: Cannot divide different types", this->dataType(), other.dataType());
     }
-    if (!this->isScalar() && other.isScalar()) {
+    if (lengthOf() != 1 && other.lengthOf() == 1) {
         NDArray::prepareSpecialUse({this}, {this, &other});
         NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
         NDArray::registerSpecialUse({this}, {this, &other});
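The four compound-assignment operators get the same relaxation: any length-1 right-hand operand is routed through the execScalar fast path, not just a rank-0 scalar. A hedged sketch, with the same assumed setup and illustrative shapes:

    auto a = NDArrayFactory::create<float>('c', {2, 2}, {1.f, 2.f, 3.f, 4.f});
    auto b = NDArrayFactory::create<float>('c', {1, 1}, {10.f});  // single element, rank 2
    a += b;                                                       // broadcast as a scalar: a becomes {11, 12, 13, 14}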
@@ -2520,12 +2520,12 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe
     if (isEmpty() || other->isEmpty())
         return;
-    if (isScalar()) {
+    if (lengthOf() == 1) {
         target->assign(this);
         target->applyPairwiseTransform(op.p, *other, extraArgs);
         return;
     }
-    if (other->isScalar()) {
+    if (other->lengthOf() == 1) {
         const_cast<NDArray*>(this)->applyScalarArr(op.s, other, target, extraArgs);
         return;
     }
@@ -2560,13 +2560,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
     if (isEmpty() || other->isEmpty())
         return;
-    if (isScalar()) {
+    if (lengthOf() == 1) {
         NDArray temp(target->_shapeInfo, dataType(), false, getContext());
         temp.assign(this);
         temp.applyPairwiseTransform(op.p, other, target, extraArgs);
         return;
     }
-    if (other->isScalar()) {
+    if (other->lengthOf() == 1) {
         this->applyScalarArr(op.s, other, target, extraArgs);
         return;
     }
@@ -2599,13 +2599,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* o
     if (isEmpty() || other->isEmpty())
         return;
-    if (isScalar()) {
+    if (lengthOf() == 1) {
         NDArray temp(target->_shapeInfo, dataType(), false, getContext());
         temp.assign(this);
         temp.applyPairwiseTransform(op.p, other, target, extraArgs);
         return;
     }
-    if (other->isScalar()) {
+    if (other->lengthOf() == 1) {
         this->applyScalarArr(op.s, other, target, extraArgs);
         return;
     }
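In the three applyTrueBroadcast variants the single-element shortcut now also fires for operands such as a {1, 1, 1} array, which previously fell through to the generic broadcast path. A sketch of the intended usage, assuming the BroadcastOpsTuple::Multiply() helper and the default trailing arguments of applyTrueBroadcast:

    auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 2.f, 3.f, 4.f, 5.f, 6.f});
    auto y = NDArrayFactory::create<float>('c', {1, 1, 1}, {2.f});
    auto z = NDArrayFactory::create<float>('c', {2, 3});
    x.applyTrueBroadcast(nd4j::BroadcastOpsTuple::Multiply(), &y, &z);  // length-1 operand, routed through applyScalarArr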
@@ -3178,9 +3178,9 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
         return;
     }
-    if (other.isScalar()) {
+    if (other.lengthOf() == 1) {
-        if(this->isScalar()) {
+        if(lengthOf() == 1) {
             NDArray::preparePrimaryUse({this}, {&other});
             BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES);
             NDArray::registerPrimaryUse({this}, {&other});
@@ -3559,7 +3559,7 @@ NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const
 void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams) {
     if (isS())
         throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!");
-    if (!scalar->isScalar())
+    if (scalar->lengthOf() != 1)
         throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!");
     if(target == nullptr)
         target = this;
@@ -3678,7 +3678,7 @@ void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const
     NDArray::prepareSpecialUse({target}, {this});
-    if (target->isScalar()) {
+    if (target->lengthOf() == 1) {
         NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo());
     }
     else {
@@ -4060,7 +4060,7 @@ template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, c
 ////////////////////////////////////////////////////////////////////////
 void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
-    if(!scalar.isScalar())
+    if(scalar.lengthOf() != 1)
         throw std::invalid_argument("NDArray::p method: input array must be scalar!");
     if (i >= _length)
         throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
@@ -4074,7 +4074,7 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
 ////////////////////////////////////////////////////////////////////////
 void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) {
-    if(!scalar.isScalar())
+    if(scalar.lengthOf() != 1)
         throw std::invalid_argument("NDArray::p method: input array must be scalar!");
     if (i >= _length)
         throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
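Finally, the element-write helpers p(i, ...) and p(i, j, k, l, ...) accept any single-element array as the value argument. A short hedged example with the same assumed setup and illustrative shapes:

    auto x = NDArrayFactory::create<float>('c', {4}, {0.f, 0.f, 0.f, 0.f});
    auto v = NDArrayFactory::create<float>('c', {1, 1}, {7.f});   // lengthOf() == 1, so the relaxed check accepts it
    x.p(2, v);                                                    // x should now be {0, 0, 7, 0}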