diff --git a/libnd4j/blas/NDArray.hpp b/libnd4j/blas/NDArray.hpp index ba0c34f6c..9c1b44818 100644 --- a/libnd4j/blas/NDArray.hpp +++ b/libnd4j/blas/NDArray.hpp @@ -1004,7 +1004,7 @@ NDArray NDArray::reduceNumber(nd4j::reduce::LongOps op, void *extraParams) const void NDArray::reduceNumber(nd4j::reduce::FloatOps op, NDArray& target, void *extraParams) const { if (isS()) throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!"); - if(!target.isScalar() || target.dataType() != DataTypeUtils::pickFloatingType(dataType())) + if(target.lengthOf() != 1 || target.dataType() != DataTypeUtils::pickFloatingType(dataType())) throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!"); NDArray::prepareSpecialUse({&target}, {this}); @@ -1017,7 +1017,7 @@ void NDArray::reduceNumber(nd4j::reduce::SameOps op, NDArray& target, void *extr if (isS()) throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!"); - if(!target.isScalar() || target.dataType() != dataType()) + if(target.lengthOf() != 1 || target.dataType() != dataType()) throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!"); NDArray::prepareSpecialUse({&target}, {this}); @@ -1030,7 +1030,7 @@ void NDArray::reduceNumber(nd4j::reduce::BoolOps op, NDArray& target, void *extr if (isS()) throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!"); - if(!target.isScalar() || target.dataType() != DataType::BOOL) + if(target.lengthOf() != 1 || target.dataType() != DataType::BOOL) throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!"); NDArray::prepareSpecialUse({&target}, {this}); @@ -1043,7 +1043,7 @@ void NDArray::reduceNumber(nd4j::reduce::LongOps op, NDArray& target, void *extr if (isS()) throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!"); - if(!target.isScalar() || target.dataType() != DataType::INT64) + if(target.lengthOf() != 1 || target.dataType() != DataType::INT64) throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!"); NDArray::prepareSpecialUse({&target}, {this}); @@ -2104,7 +2104,7 @@ void NDArray::operator+=(const NDArray& other) { if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType()); - if (!this->isScalar() && other.isScalar()) { + if (this->lengthOf() != 1 && other.lengthOf() == 1) { NDArray::prepareSpecialUse({this}, {this, &other}); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({this}, {this, &other}); @@ -2138,7 +2138,7 @@ void NDArray::operator-=(const NDArray& other) { if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType()); - if (!this->isScalar() && other.isScalar()) { + if (lengthOf() != 1 && other.lengthOf() == 1) { NDArray::prepareSpecialUse({this}, {this, &other}); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({this}, {this, &other}); @@ -2171,7 +2171,7 @@ void NDArray::operator*=(const NDArray& other) { if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL)) throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType()); - if (!this->isScalar() && other.isScalar()) { + if (lengthOf() != 1 && other.lengthOf() == 1) { NDArray::prepareSpecialUse({this}, {this, &other}); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({this}, {this, &other}); @@ -2208,7 +2208,7 @@ void NDArray::operator/=(const NDArray& other) { throw nd4j::datatype_exception::build("NDArray operator/=: Cannot divide different types", this->dataType(), other.dataType()); } - if (!this->isScalar() && other.isScalar()) { + if (lengthOf() != 1 && other.lengthOf() == 1) { NDArray::prepareSpecialUse({this}, {this, &other}); NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr); NDArray::registerSpecialUse({this}, {this, &other}); @@ -2520,12 +2520,12 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe if (isEmpty() || other->isEmpty()) return; - if (isScalar()) { + if (lengthOf() == 1) { target->assign(this); target->applyPairwiseTransform(op.p, *other, extraArgs); return; } - if (other->isScalar()) { + if (other->lengthOf() == 1) { const_cast(this)->applyScalarArr(op.s, other, target, extraArgs); return; } @@ -2560,13 +2560,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray* if (isEmpty() || other->isEmpty()) return; - if (isScalar()) { + if (lengthOf() == 1) { NDArray temp(target->_shapeInfo, dataType(), false, getContext()); temp.assign(this); temp.applyPairwiseTransform(op.p, other, target, extraArgs); return; } - if (other->isScalar()) { + if (other->lengthOf() == 1) { this->applyScalarArr(op.s, other, target, extraArgs); return; } @@ -2599,13 +2599,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* o if (isEmpty() || other->isEmpty()) return; - if (isScalar()) { + if (lengthOf() == 1) { NDArray temp(target->_shapeInfo, dataType(), false, getContext()); temp.assign(this); temp.applyPairwiseTransform(op.p, other, target, extraArgs); return; } - if (other->isScalar()) { + if (other->lengthOf() == 1) { this->applyScalarArr(op.s, other, target, extraArgs); return; } @@ -3178,9 +3178,9 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) { return; } - if (other.isScalar()) { + if (other.lengthOf() == 1) { - if(this->isScalar()) { + if(lengthOf() == 1) { NDArray::preparePrimaryUse({this}, {&other}); BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES); NDArray::registerPrimaryUse({this}, {&other}); @@ -3559,7 +3559,7 @@ NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams) { if (isS()) throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!"); - if (!scalar->isScalar()) + if (scalar->lengthOf() != 1) throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!"); if(target == nullptr) target = this; @@ -3678,7 +3678,7 @@ void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const NDArray::prepareSpecialUse({target}, {this}); - if (target->isScalar()) { + if (target->lengthOf() == 1) { NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo()); } else { @@ -4060,7 +4060,7 @@ template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, c //////////////////////////////////////////////////////////////////////// void NDArray::p(const Nd4jLong i, const NDArray& scalar) { - if(!scalar.isScalar()) + if(scalar.lengthOf() != 1) throw std::invalid_argument("NDArray::p method: input array must be scalar!"); if (i >= _length) throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !"); @@ -4074,7 +4074,7 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) { //////////////////////////////////////////////////////////////////////// void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) { - if(!scalar.isScalar()) + if(scalar.lengthOf() != 1) throw std::invalid_argument("NDArray::p method: input array must be scalar!"); if (i >= _length) throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");