Merge pull request #6 from KonduitAI/shyrma_broadcast
- replace condition isScalar() by condition length ==1 in some NDArra…master
commit
35e6ffede4
|
@ -1004,7 +1004,7 @@ NDArray NDArray::reduceNumber(nd4j::reduce::LongOps op, void *extraParams) const
|
||||||
void NDArray::reduceNumber(nd4j::reduce::FloatOps op, NDArray& target, void *extraParams) const {
|
void NDArray::reduceNumber(nd4j::reduce::FloatOps op, NDArray& target, void *extraParams) const {
|
||||||
if (isS())
|
if (isS())
|
||||||
throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!");
|
throw std::runtime_error("NDArray::reduceNumber FloatOps: you can't use this method on String array!");
|
||||||
if(!target.isScalar() || target.dataType() != DataTypeUtils::pickFloatingType(dataType()))
|
if(target.lengthOf() != 1 || target.dataType() != DataTypeUtils::pickFloatingType(dataType()))
|
||||||
throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!");
|
throw std::invalid_argument("NDArray::reduceNumber FloatOps: target array should be scalar and have corresponding float type!");
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&target}, {this});
|
NDArray::prepareSpecialUse({&target}, {this});
|
||||||
|
@ -1017,7 +1017,7 @@ void NDArray::reduceNumber(nd4j::reduce::SameOps op, NDArray& target, void *extr
|
||||||
|
|
||||||
if (isS())
|
if (isS())
|
||||||
throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!");
|
throw std::runtime_error("NDArray::reduceNumber SameOps: you can't use this method on String array!");
|
||||||
if(!target.isScalar() || target.dataType() != dataType())
|
if(target.lengthOf() != 1 || target.dataType() != dataType())
|
||||||
throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!");
|
throw std::invalid_argument("NDArray::reduceNumber SameOps: target array should be scalar and have same type as this array!");
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&target}, {this});
|
NDArray::prepareSpecialUse({&target}, {this});
|
||||||
|
@ -1030,7 +1030,7 @@ void NDArray::reduceNumber(nd4j::reduce::BoolOps op, NDArray& target, void *extr
|
||||||
|
|
||||||
if (isS())
|
if (isS())
|
||||||
throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!");
|
throw std::runtime_error("NDArray::reduceNumber BoolOps: you can't use this method on String array!");
|
||||||
if(!target.isScalar() || target.dataType() != DataType::BOOL)
|
if(target.lengthOf() != 1 || target.dataType() != DataType::BOOL)
|
||||||
throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!");
|
throw std::invalid_argument("NDArray::reduceNumber BoolOps: target array should be scalar and have bool type!");
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&target}, {this});
|
NDArray::prepareSpecialUse({&target}, {this});
|
||||||
|
@ -1043,7 +1043,7 @@ void NDArray::reduceNumber(nd4j::reduce::LongOps op, NDArray& target, void *extr
|
||||||
|
|
||||||
if (isS())
|
if (isS())
|
||||||
throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!");
|
throw std::runtime_error("NDArray::reduceNumber LongOps: you can't use this method on String array!");
|
||||||
if(!target.isScalar() || target.dataType() != DataType::INT64)
|
if(target.lengthOf() != 1 || target.dataType() != DataType::INT64)
|
||||||
throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!");
|
throw std::invalid_argument("NDArray::reduceNumber LongOps: target array should be scalar and have long type!");
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({&target}, {this});
|
NDArray::prepareSpecialUse({&target}, {this});
|
||||||
|
@ -2104,7 +2104,7 @@ void NDArray::operator+=(const NDArray& other) {
|
||||||
if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
||||||
throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType());
|
throw nd4j::datatype_exception::build("NDArray operator+=: Cannot add different types", this->dataType(), other.dataType());
|
||||||
|
|
||||||
if (!this->isScalar() && other.isScalar()) {
|
if (this->lengthOf() != 1 && other.lengthOf() == 1) {
|
||||||
NDArray::prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Add, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
||||||
NDArray::registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
|
@ -2138,7 +2138,7 @@ void NDArray::operator-=(const NDArray& other) {
|
||||||
if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
||||||
throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType());
|
throw nd4j::datatype_exception::build("NDArray operator-=: Cannot subtract different types", this->dataType(), other.dataType());
|
||||||
|
|
||||||
if (!this->isScalar() && other.isScalar()) {
|
if (lengthOf() != 1 && other.lengthOf() == 1) {
|
||||||
NDArray::prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Subtract, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
||||||
NDArray::registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
|
@ -2171,7 +2171,7 @@ void NDArray::operator*=(const NDArray& other) {
|
||||||
if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
if (!Environment::getInstance()->isExperimentalBuild() && this->dataType() != other.dataType() && (this->dataType() != DataType::BOOL || other.dataType() != BOOL))
|
||||||
throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType());
|
throw nd4j::datatype_exception::build("NDArray operator*=: Cannot multiply different types", this->dataType(), other.dataType());
|
||||||
|
|
||||||
if (!this->isScalar() && other.isScalar()) {
|
if (lengthOf() != 1 && other.lengthOf() == 1) {
|
||||||
NDArray::prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Multiply, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
||||||
NDArray::registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
|
@ -2208,7 +2208,7 @@ void NDArray::operator/=(const NDArray& other) {
|
||||||
throw nd4j::datatype_exception::build("NDArray operator/=: Cannot divide different types", this->dataType(), other.dataType());
|
throw nd4j::datatype_exception::build("NDArray operator/=: Cannot divide different types", this->dataType(), other.dataType());
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!this->isScalar() && other.isScalar()) {
|
if (lengthOf() != 1 && other.lengthOf() == 1) {
|
||||||
NDArray::prepareSpecialUse({this}, {this, &other});
|
NDArray::prepareSpecialUse({this}, {this, &other});
|
||||||
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
NativeOpExecutioner::execScalar(getContext(), nd4j::scalar::Divide, buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), buffer(), getShapeInfo(), specialBuffer(), getSpecialShapeInfo(), other.getBuffer(), other.getShapeInfo(), other.getSpecialBuffer(), other.getSpecialShapeInfo(), nullptr);
|
||||||
NDArray::registerSpecialUse({this}, {this, &other});
|
NDArray::registerSpecialUse({this}, {this, &other});
|
||||||
|
@ -2520,12 +2520,12 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastOpsTuple op, const NDArray* othe
|
||||||
if (isEmpty() || other->isEmpty())
|
if (isEmpty() || other->isEmpty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if (isScalar()) {
|
if (lengthOf() == 1) {
|
||||||
target->assign(this);
|
target->assign(this);
|
||||||
target->applyPairwiseTransform(op.p, *other, extraArgs);
|
target->applyPairwiseTransform(op.p, *other, extraArgs);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (other->isScalar()) {
|
if (other->lengthOf() == 1) {
|
||||||
const_cast<NDArray*>(this)->applyScalarArr(op.s, other, target, extraArgs);
|
const_cast<NDArray*>(this)->applyScalarArr(op.s, other, target, extraArgs);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -2560,13 +2560,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastBoolOpsTuple op, const NDArray*
|
||||||
if (isEmpty() || other->isEmpty())
|
if (isEmpty() || other->isEmpty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if (isScalar()) {
|
if (lengthOf() == 1) {
|
||||||
NDArray temp(target->_shapeInfo, dataType(), false, getContext());
|
NDArray temp(target->_shapeInfo, dataType(), false, getContext());
|
||||||
temp.assign(this);
|
temp.assign(this);
|
||||||
temp.applyPairwiseTransform(op.p, other, target, extraArgs);
|
temp.applyPairwiseTransform(op.p, other, target, extraArgs);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (other->isScalar()) {
|
if (other->lengthOf() == 1) {
|
||||||
this->applyScalarArr(op.s, other, target, extraArgs);
|
this->applyScalarArr(op.s, other, target, extraArgs);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -2599,13 +2599,13 @@ void NDArray::applyTrueBroadcast(nd4j::BroadcastIntOpsTuple op, const NDArray* o
|
||||||
if (isEmpty() || other->isEmpty())
|
if (isEmpty() || other->isEmpty())
|
||||||
return;
|
return;
|
||||||
|
|
||||||
if (isScalar()) {
|
if (lengthOf() == 1) {
|
||||||
NDArray temp(target->_shapeInfo, dataType(), false, getContext());
|
NDArray temp(target->_shapeInfo, dataType(), false, getContext());
|
||||||
temp.assign(this);
|
temp.assign(this);
|
||||||
temp.applyPairwiseTransform(op.p, other, target, extraArgs);
|
temp.applyPairwiseTransform(op.p, other, target, extraArgs);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (other->isScalar()) {
|
if (other->lengthOf() == 1) {
|
||||||
this->applyScalarArr(op.s, other, target, extraArgs);
|
this->applyScalarArr(op.s, other, target, extraArgs);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -3178,9 +3178,9 @@ void NDArray::assign(const NDArray& other, bool allowParallelism) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (other.isScalar()) {
|
if (other.lengthOf() == 1) {
|
||||||
|
|
||||||
if(this->isScalar()) {
|
if(lengthOf() == 1) {
|
||||||
NDArray::preparePrimaryUse({this}, {&other});
|
NDArray::preparePrimaryUse({this}, {&other});
|
||||||
BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES);
|
BUILD_DOUBLE_SELECTOR(dataType(), other.dataType(), templatedDoubleAssign, (buffer(), 0, other.getBuffer(), 0), LIBND4J_TYPES, LIBND4J_TYPES);
|
||||||
NDArray::registerPrimaryUse({this}, {&other});
|
NDArray::registerPrimaryUse({this}, {&other});
|
||||||
|
@ -3559,7 +3559,7 @@ NDArray NDArray::transform(nd4j::transform::BoolOps op, void *extraParams) const
|
||||||
void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams) {
|
void NDArray::applyScalarArr(nd4j::scalar::Ops op, const NDArray* scalar, NDArray* target, ExtraArguments *extraParams) {
|
||||||
if (isS())
|
if (isS())
|
||||||
throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!");
|
throw std::runtime_error("NDArray::applyScalarArr: you can't use this method on String array!");
|
||||||
if (!scalar->isScalar())
|
if (scalar->lengthOf() != 1)
|
||||||
throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!");
|
throw std::invalid_argument("NDArray::applyScalarArr method: operand is not a scalar!");
|
||||||
if(target == nullptr)
|
if(target == nullptr)
|
||||||
target = this;
|
target = this;
|
||||||
|
@ -3678,7 +3678,7 @@ void NDArray::applyIndexReduce(nd4j::indexreduce::Ops op, NDArray* target, const
|
||||||
|
|
||||||
NDArray::prepareSpecialUse({target}, {this});
|
NDArray::prepareSpecialUse({target}, {this});
|
||||||
|
|
||||||
if (target->isScalar()) {
|
if (target->lengthOf() == 1) {
|
||||||
NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo());
|
NativeOpExecutioner::execIndexReduceScalar(getContext(), op, getBuffer(), getShapeInfo(), getSpecialBuffer(), getSpecialShapeInfo(), params, target->buffer(), target->shapeInfo(), target->specialBuffer(), target->specialShapeInfo());
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -4060,7 +4060,7 @@ template void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, c
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
|
void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
|
||||||
|
|
||||||
if(!scalar.isScalar())
|
if(scalar.lengthOf() != 1)
|
||||||
throw std::invalid_argument("NDArray::p method: input array must be scalar!");
|
throw std::invalid_argument("NDArray::p method: input array must be scalar!");
|
||||||
if (i >= _length)
|
if (i >= _length)
|
||||||
throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
|
throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
|
||||||
|
@ -4074,7 +4074,7 @@ void NDArray::p(const Nd4jLong i, const NDArray& scalar) {
|
||||||
////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////
|
||||||
void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) {
|
void NDArray::p(const Nd4jLong i, const Nd4jLong j, const Nd4jLong k, const Nd4jLong l, const NDArray& scalar) {
|
||||||
|
|
||||||
if(!scalar.isScalar())
|
if(scalar.lengthOf() != 1)
|
||||||
throw std::invalid_argument("NDArray::p method: input array must be scalar!");
|
throw std::invalid_argument("NDArray::p method: input array must be scalar!");
|
||||||
if (i >= _length)
|
if (i >= _length)
|
||||||
throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
|
throw std::invalid_argument("NDArray::p(i, NDArray_scalar): input index is out of array length !");
|
||||||
|
|
Loading…
Reference in New Issue