- new implementations for Index Reductions (#421)

* - new implementations for Index Reductions
- small fix in the legacy reduction
- disabled index reduction bench tests inside Playground

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* Allow LIBND4J_TYPES

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* index reduction stuff split into bunch of units

* meh

* IMax switched to new impl

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* minor fix + test

* minor fix

* index range fix

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* noop on empty outputs

* minor fix

* minor fix

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* ArgMax replaces IMax

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* argmax/argmin/argamax/argamin shape functions updated

* ArgAmax/ArgAmin/ArgMin replaces IAMax/IAMin/IMin

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* argmax/argmin/argamax/argamin CUDA

* IMax replaced in dl4j

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* Codegen output

* imports fixed

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* fix compilation issue

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* Auto-generate compilation units

Signed-off-by: Abdelrauf <rauf@konduit.ai>

* Should fix NDArray refactored function calls in indexReductions.cu

Signed-off-by: Abdelrauf <rauf@konduit.ai>

Co-authored-by: raver119@gmail.com <raver119@gmail.com>
Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>
Branch: master
Author: Abdelrauf, 2020-05-14 14:41:55 +04:00 (committed by GitHub)
Parent: 62e9dc83e0
Commit: 69d91e272a
55 changed files with 2742 additions and 488 deletions

--- changed file ---

@@ -20,7 +20,7 @@ import org.deeplearning4j.clustering.algorithm.Distance;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.ReduceOp;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.common.primitives.Pair;
@@ -29,7 +29,7 @@ public class CentersHolder {
     private long index = 0;

     protected transient ReduceOp op;
-    protected IMin imin;
+    protected ArgMin imin;
     protected transient INDArray distances;
     protected transient INDArray argMin;
@@ -60,7 +60,7 @@ public class CentersHolder {
         if (op == null) {
             op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
-            imin = new IMin(distances, argMin);
+            imin = new ArgMin(distances, argMin);
             op.setZ(distances);
         }
@@ -84,7 +84,7 @@ public class CentersHolder {
         if (op == null) {
             op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
-            imin = new IMin(distances, argMin);
+            imin = new ArgMin(distances, argMin);
             op.setZ(distances);
         }

--- changed file ---

@@ -23,6 +23,7 @@ import org.deeplearning4j.BaseDL4JTest;
 import org.junit.Rule;
 import org.junit.rules.TemporaryFolder;
 import org.nd4j.common.io.ClassPathResource;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
 import org.deeplearning4j.models.word2vec.VocabWord;
 import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
 import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
@@ -31,7 +32,6 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac
 import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
 import org.junit.Test;
 import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
 import org.nd4j.linalg.dataset.DataSet;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.common.util.SerializationUtils;
@@ -111,7 +111,7 @@ public class BagOfWordsVectorizerTest extends BaseDL4JTest {
     INDArray labelz = dataSet.getLabels();
     log.info("Labels array: " + labelz);
-    int idx2 = Nd4j.getExecutioner().exec(new IMax(labelz)).getInt(0);
+    int idx2 = Nd4j.getExecutioner().exec(new ArgMax(labelz))[0].getInt(0);
     //int idx2 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(labelz))).getFinalResult().intValue();
     // assertEquals(1.0, dataSet.getLabels().getDouble(0), 0.1);
@@ -125,7 +125,7 @@ public class BagOfWordsVectorizerTest extends BaseDL4JTest {
     assertEquals(1, dataSet.getFeatures().getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
     assertEquals(0, dataSet.getFeatures().getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);
-    int idx1 = Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels())).getInt(0);
+    int idx1 = Nd4j.getExecutioner().exec(new ArgMax(dataSet.getLabels()))[0].getInt(0);
     //int idx1 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels()))).getFinalResult().intValue();
     //assertEquals(0.0, dataSet.getLabels().getDouble(0), 0.1);

--- changed file ---

@@ -294,12 +294,26 @@ elseif(SD_CPU)
     file(GLOB_RECURSE LEGACY_SOURCES false ../include/legacy/impl/*.cpp ../include/legacy/cpu/*.cpp ../include/legacy/*.h)
     file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)

+    file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in)
+    foreach(FL_ITEM ${COMPILATION_UNITS})
+        string(REGEX MATCH "^(.*)\\.cpp\.in$" dummy ${FL_ITEM})
+        set(FL_ITEM_WLE ${CMAKE_MATCH_1})
+        foreach(FL_TYPE_INDEX RANGE 0 9)
+            message("${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp")
+            configure_file("${FL_ITEM}" "${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp" @ONLY)
+            list(APPEND CUSTOMOPS_GENERIC_SOURCES ${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp)
+        endforeach()
+    endforeach()
+
     if (SD_X86_BUILD)
         # we disable platform optimizations for certain files for linux/macos
         set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
         set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
     endif()

     if(SD_CHECK_VECTORIZATION)
         set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
         if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")

--- changed file ---

@@ -19,12 +19,13 @@
 //
 #ifndef LIBND4J_LOOPCOORDSHELPER_H
 #define LIBND4J_LOOPCOORDSHELPER_H
+#include <vector>
 #include <cstddef>
 #include <type_traits>
 #include <utility>
 #include <system/pointercast.h>
 #include <system/op_boilerplate.h>
+#include <helpers/shape.h>

 namespace sd {

 #if defined(__GNUC__)
@@ -125,7 +126,7 @@ namespace sd {
 }

-FORCEINLINE zip_size_t offset_from_coords(const Nd4jLong*& x_strides, const Nd4jLong*& z_strides, const Nd4jLong* coords, const Nd4jLong& rank) {
+FORCEINLINE zip_size_t offset_from_coords(const Nd4jLong* x_strides, const Nd4jLong* z_strides, const Nd4jLong* coords, const Nd4jLong& rank) {
     zip_size_t offset = { 0,0 };
     size_t rank_4 = rank & -4;
@@ -435,6 +436,509 @@ namespace sd {
     return last_offset;
 }
struct triple_size_t {
size_t first;
size_t second;
size_t third;
};
template<bool Last_Index_Faster = true>
FORCEINLINE triple_size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* y_strides, const Nd4jLong* z_strides, Nd4jLong* coords, triple_size_t last_offset, const size_t rank, const size_t skip = 0) {
Nd4jLong val = 0;
for (int i = rank - skip - 1; i >= 0; i--) {
val = coords[i] + 1;
if (likely(val < bases[i])) {
coords[i] = val;
last_offset.first += x_strides[i];
last_offset.second += y_strides[i];
last_offset.third += z_strides[i];
break;
}
else {
last_offset.first -= coords[i] * x_strides[i];
last_offset.second -= coords[i] * y_strides[i];
last_offset.third -= coords[i] * z_strides[i];
coords[i] = 0;
}
}
return last_offset;
}
template<>
FORCEINLINE triple_size_t inc_coords<false>(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* y_strides, const Nd4jLong* z_strides, Nd4jLong* coords, triple_size_t last_offset, const size_t rank, const size_t skip) {
Nd4jLong val = 0;
for (int i = skip; i < rank; i++) {
val = coords[i] + 1;
if (likely(val < bases[i])) {
coords[i] = val;
last_offset.first += x_strides[i];
last_offset.second += y_strides[i];
last_offset.third += z_strides[i];
break;
}
else {
last_offset.first -= coords[i] * x_strides[i];
last_offset.second -= coords[i] * y_strides[i];
last_offset.third -= coords[i] * z_strides[i];
coords[i] = 0;
}
}
return last_offset;
}
FORCEINLINE triple_size_t offset_from_coords(const Nd4jLong* x_strides, const Nd4jLong* y_strides, const Nd4jLong* z_strides, const Nd4jLong* coords, const Nd4jLong& rank) {
triple_size_t offset = { 0,0 ,0 };
size_t rank_4 = rank & -4;
for (int i = 0; i < rank_4; i += 4) {
offset.first = offset.first
+ coords[i] * x_strides[i]
+ coords[i + 1] * x_strides[i + 1]
+ coords[i + 2] * x_strides[i + 2]
+ coords[i + 3] * x_strides[i + 3];
offset.second = offset.second
+ coords[i] * y_strides[i]
+ coords[i + 1] * y_strides[i + 1]
+ coords[i + 2] * y_strides[i + 2]
+ coords[i + 3] * y_strides[i + 3];
offset.third = offset.third
+ coords[i] * z_strides[i]
+ coords[i + 1] * z_strides[i + 1]
+ coords[i + 2] * z_strides[i + 2]
+ coords[i + 3] * z_strides[i + 3];
}
for (int i = rank_4; i < rank; i++) {
offset.first += coords[i] * x_strides[i];
offset.second += coords[i] * y_strides[i];
offset.third += coords[i] * z_strides[i];
}
return offset;
}
template<bool Last_Index_Faster = true>
FORCEINLINE Nd4jLong getLength(const Nd4jLong* bases, int rank, int skip = 0)
{
if (skip < 0 || skip >= rank) skip = 0;
Nd4jLong total = 1;
for (int i = 0; i < rank - skip; i++) {
total *= bases[i];
}
return total;
}
template<>
FORCEINLINE Nd4jLong getLength<false>(const Nd4jLong* bases, int rank, int skip)
{
if (skip < 0 || skip >= rank) skip = 0;
Nd4jLong total = 1;
for (int i = skip; i < rank; i++) {
total *= bases[i];
}
return total;
}
template<bool Last_Index_Faster = true>
FORCEINLINE Nd4jLong getLength(const Nd4jLong* bases, int rank, int skip, Nd4jLong& outSkippedLength)
{
if (skip < 0 || skip >= rank) skip = 0;
Nd4jLong total = 1;
for (int i = 0; i < rank - skip; i++) {
total *= bases[i];
}
if (skip > 0) {
outSkippedLength = 1;
for (int i = rank - skip; i < rank; i++) {
outSkippedLength *= bases[i];
}
}
else {
outSkippedLength = 0;
}
return total;
}
template<>
FORCEINLINE Nd4jLong getLength<false>(const Nd4jLong* bases, int rank, int skip, Nd4jLong& outSkippedLength)
{
if (skip < 0 || skip >= rank) skip = 0;
if (skip > 0) {
outSkippedLength = 1;
for (int i = 0; i < skip; i++) {
outSkippedLength *= bases[i];
}
}
else {
outSkippedLength = 0;
}
Nd4jLong total = 1;
for (int i = skip; i < rank; i++) {
total *= bases[i];
}
return total;
}
/*
for the ODR rule it will be declared inline

rePartition for reductions et cetera:
indices mentioned in the dimension list will be moved to the tail.
This way the shape is split into two parts:
the first part describes the output; the second (tail) part is used for reductions and other purposes.
If squash is true it will attempt to minimize the output (for both orders) and the tail.
*/
FORCEINLINE void rePartition(char order, const std::vector<int>& dimensions, const size_t rank, const Nd4jLong* bases, const Nd4jLong* strides, Nd4jLong(&new_bases)[MAX_RANK], Nd4jLong(&new_strides)[MAX_RANK], int& first_begin, int& first_end, int& second_begin, int& second_end, bool first_squash = false, bool second_squash = true) {
bool indices[MAX_RANK] = {};
int ind = 0;
size_t second_rank;
if (dimensions.size() == 0 || (dimensions.size() == 1 && dimensions.at(0) == sd::DataTypeUtils::max<int>())){
first_end = 0;
first_begin = 0;
//treat it as the whole
for (int i = 0; i < rank; i++) {
new_bases[i] = bases[i];
new_strides[i] = strides[i];
}
second_rank = rank;
second_end = rank;
second_begin = 0;
}
else {
for (int index : dimensions) {
if (index < 0) index = rank + index;
if (index >= 0 && index < rank) {
indices[index] = true;
}
}
//move the non-reduced (output) dimensions to the front
for (int i = 0; i < rank; i++) {
if (!indices[i]) {
new_bases[ind] = bases[i];
new_strides[ind] = strides[i];
ind++;
}
}
int first_rank = ind;
first_end = ind;
first_begin = 0;
//nd4j_printf("rffrr ss & %d ind-- %d %d\n", first_rank, first_begin, first_end);
//squash output rank
if (first_squash && first_rank > 1) {
if (order == 'c') {
int uniq_ind = first_end-1;
for (int i = first_end - 2; i >= first_begin; i--) {
if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) {
new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind];
new_strides[uniq_ind] = new_strides[uniq_ind];
--first_rank;
}
else {
--uniq_ind;
new_bases[uniq_ind] = new_bases[i];
new_strides[uniq_ind] = new_strides[i];
}
}
first_begin = first_end - first_rank;
}
else {
//squash fortran
int uniq_ind = 0;
for (int i = 1; i < first_end; i++) {
if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) {
new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind];
new_strides[uniq_ind] = new_strides[uniq_ind];
--first_rank;
}
else {
uniq_ind++;
new_bases[uniq_ind] = new_bases[i];
new_strides[uniq_ind] = new_strides[i];
}
}
first_end = first_begin + first_rank;
}
ind = first_end;
}
//nd4j_printf("rffrr ss & %d ind-- %d %d\n", first_rank, first_begin, first_end);
//move process indices
for (int i = 0; i < rank; i++) {
if (indices[i]) {
new_bases[ind] = bases[i];
new_strides[ind] = strides[i];
ind++;
}
}
second_rank = ind - first_end;
second_end = ind;
second_begin = first_end;
}
if (second_squash && second_rank > 1) {
if (order == 'c') {
int uniq_ind = second_end - 1;
for (int i = second_end - 2; i >= second_begin; i--) {
if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) {
new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind];
new_strides[uniq_ind] = new_strides[uniq_ind];
--second_rank;
}
else {
--uniq_ind;
new_bases[uniq_ind] = new_bases[i];
new_strides[uniq_ind] = new_strides[i];
}
}
second_begin = second_end - second_rank;
}
else {
int uniq_ind = second_begin;
for (int i = second_begin+1; i < second_end; i++) {
if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) {
new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind];
new_strides[uniq_ind] = new_strides[uniq_ind];
--second_rank;
}
else {
uniq_ind++;
new_bases[uniq_ind] = new_bases[i];
new_strides[uniq_ind] = new_strides[i];
}
}
second_end = second_begin + second_rank;
}
}
return;
}
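A usage sketch of rePartition (shapes and strides assumed for illustration): reducing a c-order {2, 3, 4} array along axis 1.

// after the call, [first_begin, first_end) describes the output part and
// [second_begin, second_end) the tail the reduction will run over
Nd4jLong bases[]   = {2, 3, 4};
Nd4jLong strides[] = {12, 4, 1};
Nd4jLong newBases[MAX_RANK], newStrides[MAX_RANK];
int firstBegin, firstEnd, secondBegin, secondEnd;
sd::rePartition('c', {1}, 3, bases, strides, newBases, newStrides,
                firstBegin, firstEnd, secondBegin, secondEnd);
// output part: bases {2, 4}, strides {12, 1}; tail part: base {3}, stride 4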
//basic CRTP static polymorphism classes for offset increments
template<typename Derived>
struct CoordsBaseMovement {
void init(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
static_cast<Derived*>(this)->initImpl(bases, strides1, strides2, rank, start);
}
void increment(int skipRank = 0) {
static_cast<Derived*>(this)->incrementImpl(skipRank);
}
Nd4jLong First() { return static_cast<Derived*>(this)->FirstImpl(); };
Nd4jLong Second() { return static_cast<Derived*>(this)->SecondImpl(); };
};
struct ZipGenericCoordsRank1Stride1 : CoordsBaseMovement<ZipGenericCoordsRank1Stride1> {
size_t offset1;
size_t offset2;
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
offset1 = start;
offset2 = start;
}
void incrementImpl(int skipRank = 0) {
offset1 += 1;
offset2 += 1;
}
Nd4jLong FirstImpl() { return offset1; };
Nd4jLong SecondImpl() { return offset2; };
};
struct ZipGenericCoordsRank1BothStrideN : CoordsBaseMovement<ZipGenericCoordsRank1BothStrideN> {
size_t stride1;
size_t stride2;
size_t offset1;
size_t offset2;
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
stride1 = strides1[0];
stride2 = strides2[0];
offset1 = start * stride1;
offset2 = start * stride2;
}
void incrementImpl(int skipRank = 0) {
offset1 += stride1;
offset2 += stride2;
}
Nd4jLong FirstImpl() { return offset1; };
Nd4jLong SecondImpl() { return offset2; };
};
template<int ConstRank, bool LastIndexFaster = true>
struct ZipGenericCoordsConstMovementSecondStride1 : CoordsBaseMovement<ZipGenericCoordsConstMovementSecondStride1<ConstRank, LastIndexFaster>> {
sd::CoordsState<ConstRank - 1> cst;
Nd4jLong coords[MAX_RANK];
size_t offset1;
size_t offset2;
int _rank;
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
offset1 = sd::init_coords<ConstRank, 0, LastIndexFaster>(cst, start, bases, strides1);
offset2 = start * 1;
}
void incrementImpl(int skipRank = 0) {
offset1 = sd::inc_coords<ConstRank, 0, LastIndexFaster>(cst, offset1);
offset2 += 1;
}
Nd4jLong FirstImpl() { return offset1; };
Nd4jLong SecondImpl() { return offset2; };
};
template<int ConstRank, bool LastIndexFaster = true>
struct ZipGenericCoordsConstMovementSecondStrideN : CoordsBaseMovement<ZipGenericCoordsConstMovementSecondStrideN<ConstRank, LastIndexFaster>> {
sd::CoordsState<ConstRank - 1> cst;
Nd4jLong _stride2;
Nd4jLong coords[MAX_RANK];
size_t offset1;
size_t offset2;
int _rank;
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
_stride2 = strides2[0];
offset1 = sd::init_coords<ConstRank, 0, LastIndexFaster>(cst, start, bases, strides1);
offset2 = start * _stride2;
}
void incrementImpl(int skipRank = 0) {
offset1 = sd::inc_coords<ConstRank, 0, LastIndexFaster>(cst, offset1);
offset2 += _stride2;
}
Nd4jLong FirstImpl() { return offset1; };
Nd4jLong SecondImpl() { return offset2; };
};
template<bool LastIndexFaster = true>
struct ZipGenericCoordsMovementSecondStrideN : CoordsBaseMovement<ZipGenericCoordsMovementSecondStrideN<LastIndexFaster>> {
const Nd4jLong* _bases;
const Nd4jLong* _strides1;
Nd4jLong _stride2;
Nd4jLong coords[MAX_RANK];
zip_size_t offset;
int _rank;
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
_bases = bases;
_strides1 = strides1;
_stride2 = strides2[0];
_rank = rank;
if (start == 0) {
for (int i = 0; i < MAX_RANK; i++) {
coords[i] = 0;
}
offset = { 0,0 };
}
else {
if (LastIndexFaster) {
sd::index2coords_C(start, rank, bases, (Nd4jLong*)&coords);
}
else {
sd::index2coords_F(start, rank, bases, (Nd4jLong*)&coords);
}
offset.first = sd::offset_from_coords(strides1, (Nd4jLong*)&coords, rank);
offset.second = start * _stride2;
}
}
void incrementImpl(int skipRank = 0) {
offset.first = inc_coords<LastIndexFaster>(_bases, _strides1, (Nd4jLong*)&coords, offset.first, _rank, skipRank);
offset.second += _stride2;
}
Nd4jLong FirstImpl() { return offset.first; };
Nd4jLong SecondImpl() { return offset.second; };
};
template<bool LastIndexFaster = true>
struct ZipGenericCoordsMovement : CoordsBaseMovement<ZipGenericCoordsMovement<LastIndexFaster>> {
const Nd4jLong* _bases;
const Nd4jLong* _strides1;
const Nd4jLong* _strides2;
Nd4jLong coords[MAX_RANK];
zip_size_t offset;
int _rank;
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
_bases = bases;
_strides1 = strides1;
_strides2 = strides2;
_rank = rank;
if (start == 0) {
for (int i = 0; i < MAX_RANK; i++) {
coords[i] = 0;
}
offset = { 0,0 };
}
else {
if (LastIndexFaster) {
sd::index2coords_C(start, rank, bases, (Nd4jLong*)&coords);
}
else {
sd::index2coords_F(start, rank, bases, (Nd4jLong*)&coords);
}
offset = sd::offset_from_coords(strides1, strides2, (Nd4jLong*)&coords, rank);
}
}
void incrementImpl(int skipRank = 0) {
offset = inc_coords<LastIndexFaster>(_bases, _strides1, _strides2, (Nd4jLong*)&coords, offset, _rank, skipRank);
}
Nd4jLong FirstImpl() { return offset.first; };
Nd4jLong SecondImpl() { return offset.second; };
};
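The CRTP base above fixes the contract; a generic kernel can then be written once against it and instantiated with the cheapest mover the shapes allow. A hedged sketch (zipWalk is illustrative, not part of the header):

template<typename Mover>
void zipWalk(Mover& mover, const Nd4jLong* bases, const Nd4jLong* strides1,
             const Nd4jLong* strides2, int rank, Nd4jLong len) {
    mover.init(bases, strides1, strides2, rank, 0);
    for (Nd4jLong i = 0; i < len; i++) {
        // mover.First() / mover.Second() are the current offsets into the two buffers
        mover.increment();
    }
}
// e.g. when both buffers are contiguous:
// ZipGenericCoordsRank1Stride1 mover; zipWalk(mover, bases, strides1, strides2, 1, len);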
}
#endif

--- changed file ---

@@ -69,7 +69,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(const void *vx, const Nd4jLong *xShapeInf
 for (int e = 0; e < maxThreads; e++)
     intermediatery[e].index = -1;

-if (xEws == 1) {
+if (xEws == 1 && shape::order(xShapeInfo) == 'c') {
     auto func = PRAGMA_THREADS_FOR {
         intermediatery[thread_id] = OpType::startingIndexValue(x);

--- changed file ---

@@ -188,7 +188,7 @@ namespace functions {
     auto reductionBuffer = static_cast<X*>(vreductionBuffer);
     auto order = shape::order(xShapeInfo);
     int tid = blockIdx.x * blockDim.x + threadIdx.x;
-    __shared__ volatile int resultScalar;
+    __shared__ volatile bool resultScalar;

     //shared memory space for storing intermediate results
     __shared__ IndexValue<X>* sPartials;
@@ -214,17 +214,10 @@ namespace functions {
         zLen = shape::length(zShapeInfo);
     else zLen = 1;

-    if (dimensionLength == 1) {
-        if (zLen == 1 && (dimension == nullptr || dimension[0] == MAX_DIMENSION))
-            resultScalar = 1;
-        else
-            resultScalar = 0;
-    }
-    else
-        resultScalar = 0;
-
     if (zLen == 1)
-        resultScalar = 1;
+        resultScalar = true;
+    else
+        resultScalar = false;

     xLength = shape::length(xShapeInfo);
 }

--- changed file ---

@@ -0,0 +1,95 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
// Created by Abdelrauf 2020 (based on argmax)
#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_argamax)
#include <ops/declarable/helpers/axis.h>
#include <ops/declarable/helpers/reductions.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
namespace sd {
namespace ops {
DECLARE_TYPES(argamax) {
getOpDescriptor()
->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS })
->setAllowedOutputTypes({ ALL_INTS });
}
CUSTOM_OP_IMPL(argamax, 1, 1, false, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
if (output->isEmpty())
return Status::OK();
auto axis = *block.getIArguments();
// axis might be dynamic (i.e. tf mode)
if (block.width() > 1 && axis.size() == 0) {
auto axisVector = INPUT_VARIABLE(1);
helpers::adjustAxis(input->rankOf(), axisVector, axis);
helpers::argAbsMax(*input, *output, axis);
}
else {
helpers::argAbsMax(*input, *output, axis);
}
STORE_RESULT(output);
return Status::OK();
}
DECLARE_SHAPE_FN(argamax) {
std::vector<int> dims;
if (block.width() == 1) {
dims = *block.getIArguments();
} else {
auto y = INPUT_VARIABLE(1);
dims = y->template asVectorT<int>();
}
auto keepDims = block.numB() ? B_ARG(0) : false;
auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;
// we're resolving negative axis here
helpers::adjustAxis(shape::rank(inputShape->at(0)), dims);
auto in = inputShape->at(0);
for (auto d : dims) {
// we have special case here
if (d == sd::DataTypeUtils::max<int>())
continue;
REQUIRE_TRUE(d < shape::rank(in), 0, "ArgAmax: axis can't be above rank")
REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgAmax: you can't reduce along axis with 0 in shape");
}
// special case - output is scalar
if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype));
}
return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace()));
}
}
}
#endif

--- changed file ---

@@ -0,0 +1,95 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
* Copyright (c) 2019 Konduit K.K.
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
// Created by Abdelrauf 2020 (based on argmax)
#include <system/op_boilerplate.h>
#if NOT_EXCLUDED(OP_argamin)
#include <ops/declarable/helpers/axis.h>
#include <ops/declarable/helpers/reductions.h>
#include <ops/declarable/CustomOperations.h>
#include <helpers/ConstantTadHelper.h>
namespace sd {
namespace ops {
DECLARE_TYPES(argamin) {
getOpDescriptor()
->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS })
->setAllowedOutputTypes({ ALL_INTS });
}
CUSTOM_OP_IMPL(argamin, 1, 1, false, 0, -2) {
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
if (output->isEmpty())
return Status::OK();
auto axis = *block.getIArguments();
// axis might be dynamic (i.e. tf mode)
if (block.width() > 1 && axis.size() == 0) {
auto axisVector = INPUT_VARIABLE(1);
helpers::adjustAxis(input->rankOf(), axisVector, axis);
helpers::argAbsMin(*input, *output, axis);
}
else {
helpers::argAbsMin(*input, *output, axis);
}
STORE_RESULT(output);
return Status::OK();
}
DECLARE_SHAPE_FN(argamin) {
std::vector<int> dims;
if (block.width() == 1) {
dims = *block.getIArguments();
} else {
auto y = INPUT_VARIABLE(1);
dims = y->template asVectorT<int>();
}
auto keepDims = block.numB() ? B_ARG(0) : false;
auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;
// we're resolving negative axis here
helpers::adjustAxis(shape::rank(inputShape->at(0)), dims);
auto in = inputShape->at(0);
for (auto d : dims) {
// we have special case here
if (d == sd::DataTypeUtils::max<int>())
continue;
REQUIRE_TRUE(d < shape::rank(in), 0, "ArgAmin: axis can't be above rank")
REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgAmin: you can't reduce along axis with 0 in shape");
}
// special case - output is scalar
if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype));
}
return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace()));
}
}
}
#endif

--- changed file ---

@@ -1,6 +1,6 @@
 /*******************************************************************************
  * Copyright (c) 2015-2018 Skymind, Inc.
- *
+ * Copyright (c) 2019 Konduit K.K.
  * This program and the accompanying materials are made available under the
  * terms of the Apache License, Version 2.0 which is available at
  * https://www.apache.org/licenses/LICENSE-2.0.
@@ -22,6 +22,7 @@
 #if NOT_EXCLUDED(OP_argmax)
 #include <ops/declarable/helpers/axis.h>
+#include <ops/declarable/helpers/reductions.h>
 #include <ops/declarable/CustomOperations.h>
 #include <helpers/ConstantTadHelper.h>
@@ -29,7 +30,7 @@ namespace sd {
 namespace ops {
     DECLARE_TYPES(argmax) {
         getOpDescriptor()
-            ->setAllowedInputTypes(sd::DataType::ANY)
+            ->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS })
            ->setAllowedOutputTypes({ALL_INTS});
     }
@@ -37,18 +38,19 @@ namespace sd {
     auto input = INPUT_VARIABLE(0);
     auto output = OUTPUT_VARIABLE(0);

-    auto axis = *block.getIArguments();
+    if (output->isEmpty())
+        return Status::OK();
+
+    auto axis = *block.getIArguments();

     // axis might be dynamic (i.e. tf mode)
     if (block.width() > 1 && axis.size() == 0) {
         auto axisVector = INPUT_VARIABLE(1);
         helpers::adjustAxis(input->rankOf(), axisVector, axis);
-        input->applyIndexReduce(indexreduce::IndexMax, *output, axis);
+        helpers::argMax(*input, *output, axis);
     } else {
-        helpers::adjustAxis(input->rankOf(), axis);
-        input->applyIndexReduce(indexreduce::IndexMax, *output, axis);
+        helpers::argMax(*input, *output, axis);
     }

     STORE_RESULT(output);
@@ -66,23 +68,28 @@ namespace sd {
         dims = y->template asVectorT<int>();
     }

+    auto keepDims = block.numB() ? B_ARG(0) : false;
+    auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;
+
     // we're resolving negative axis here
     helpers::adjustAxis(shape::rank(inputShape->at(0)), dims);

-    if (dims.size() > 1)
-        std::sort(dims.begin(), dims.end());
-
-    for (auto d:dims) {
-        REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape");
+    auto in = inputShape->at(0);
+    for (auto d : dims) {
+        // we have special case here
+        if (d == sd::DataTypeUtils::max<int>())
+            continue;
+        REQUIRE_TRUE(d < shape::rank(in), 0, "ArgMax: axis can't be above rank")
+        REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape");
     }

     // special case - output is scalar
-    if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
-        return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64));
+    if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
+        return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype));
     }

-    return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), DataType::INT64, false, false, block.getWorkspace()));
+    return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace()));
 }
 }
 }

--- changed file ---

@@ -21,15 +21,17 @@
 #include <system/op_boilerplate.h>
 #if NOT_EXCLUDED(OP_argmin)

-#include <ops/declarable/CustomOperations.h>
 #include <ops/declarable/helpers/axis.h>
+#include <ops/declarable/helpers/reductions.h>
+#include <ops/declarable/CustomOperations.h>
+#include <helpers/ConstantTadHelper.h>

 namespace sd {
 namespace ops {
     DECLARE_TYPES(argmin) {
         getOpDescriptor()
-            ->setAllowedInputTypes(sd::DataType::ANY)
+            ->setAllowedInputTypes({ ALL_FLOATS, ALL_INTS })
            ->setAllowedOutputTypes({ALL_INTS});
     }
@@ -39,16 +41,18 @@ namespace sd {
     auto output = OUTPUT_VARIABLE(0);

+    if (output->isEmpty())
+        return Status::OK();
+
     // axis might be dynamic (i.e. tf mode)
     if (block.width() > 1 && axis.size() == 0) {
         auto axisVector = INPUT_VARIABLE(1);
         helpers::adjustAxis(input->rankOf(), axisVector, axis);
-        input->applyIndexReduce(indexreduce::IndexMin, *output, axis);
+        helpers::argMin(*input, *output, axis);
     } else {
-        helpers::adjustAxis(input->rankOf(), axis);
-        input->applyIndexReduce(indexreduce::IndexMin, *output, axis);
+        helpers::argMin(*input, *output, axis);
     }

     STORE_RESULT(output);
@@ -58,7 +62,7 @@ namespace sd {
 DECLARE_SHAPE_FN(argmin) {
     std::vector<int> dims;
-    auto in = inputShape->at(0);
+
     if (block.width() == 1) {
         dims = *block.getIArguments();
     } else {
@@ -66,23 +70,28 @@ namespace sd {
         dims = y->template asVectorT<int>();
     }

+    auto keepDims = block.numB() ? B_ARG(0) : false;
+    auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;
+
     // we're resolving negative axis here
-    helpers::adjustAxis(shape::rank(in), dims);
+    helpers::adjustAxis(shape::rank(inputShape->at(0)), dims);

-    if (dims.size() > 1)
-        std::sort(dims.begin(), dims.end());
-
-    for (auto d:dims) {
-        REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape");
+    auto in = inputShape->at(0);
+    for (auto d : dims) {
+        // we have special case here
+        if (d == sd::DataTypeUtils::max<int>())
+            continue;
+        REQUIRE_TRUE(d < shape::rank(in), 0, "ArgMin: axis can't be above rank")
+        REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape");
     }

     // special case - output is scalar
-    if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
-        return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64));
+    if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
+        return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype));
     }

-    auto newShape = ShapeUtils::evalReduceShapeInfo('c', dims, in, DataType::INT64, false, false, block.getWorkspace());
-    return SHAPELIST(newShape);
+    return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace()));
 }
 }

--- changed file ---

@@ -52,6 +52,32 @@ namespace sd {
         DECLARE_CUSTOM_OP(argmin, 1, 1, false, 0, -2);
         #endif

+        /**
+         * This operation returns index of absolute max element in a given NDArray (optionally: along given dimension(s))
+         * Expected input:
+         * 0: N-dimensional array
+         * 1: optional axis vector
+         *
+         * Int args:
+         * 0: optional axis
+         */
+        #if NOT_EXCLUDED(OP_argamax)
+        DECLARE_CUSTOM_OP(argamax, 1, 1, false, 0, -2);
+        #endif
+
+        /**
+         * This operation returns index of absolute min element in a given NDArray (optionally: along given dimension(s))
+         * Expected input:
+         * 0: N-dimensional array
+         * 1: optional axis vector
+         *
+         * Int args:
+         * 0: optional axis
+         */
+        #if NOT_EXCLUDED(OP_argamin)
+        DECLARE_CUSTOM_OP(argamin, 1, 1, false, 0, -2);
+        #endif
+
         /**
          * This operation provides various normalization modes:
          * 0: frobenius
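Usage sketch for the two new ops declared above (the values and the evaluate()/ResultSet calling convention are assumptions based on how libnd4j custom ops are normally invoked, not part of this diff):

#include <ops/declarable/CustomOperations.h>
#include <array/NDArrayFactory.h>

void argamaxExample() {
    // row 0: |-7| beats 1 and 2; row 1: 5 is the plain max
    auto x = NDArrayFactory::create<float>('c', {2, 3}, {-7.f, 1.f, 2.f, 3.f, 4.f, 5.f});
    sd::ops::argamax op;
    auto result = op.evaluate({&x}, {}, {1});   // reduce along axis 1
    auto z = result.at(0);                      // expected indices: [0, 2]
    z->printIndexedBuffer("argamax along axis 1");
}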

--- changed file ---

@@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author AbdelRauf
//
#include <ops/declarable/helpers/cpu/indexReductions.hpp>
namespace sd {
namespace ops {
namespace helpers {
BUILD_DOUBLE_TEMPLATE(template void argAbsMax_, (const NDArray& input, NDArray& output, const std::vector<int>& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES);
}
}
}

--- changed file ---

@@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author AbdelRauf
//
#include <ops/declarable/helpers/cpu/indexReductions.hpp>
namespace sd {
namespace ops {
namespace helpers {
BUILD_DOUBLE_TEMPLATE(template void argAbsMin_, (const NDArray& input, NDArray& output, const std::vector<int>& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES);
}
}
}

--- changed file ---

@@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author AbdelRauf
//
#include <ops/declarable/helpers/cpu/indexReductions.hpp>
namespace sd {
namespace ops {
namespace helpers {
BUILD_DOUBLE_TEMPLATE(template void argMax_, (const NDArray& input, NDArray& output, const std::vector<int>& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES);
}
}
}

--- changed file ---

@@ -0,0 +1,28 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author AbdelRauf
//
#include <ops/declarable/helpers/cpu/indexReductions.hpp>
namespace sd {
namespace ops {
namespace helpers {
BUILD_DOUBLE_TEMPLATE(template void argMin_, (const NDArray& input, NDArray& output, const std::vector<int>& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES);
}
}
}

--- changed files: ten generated crop_and_resize compilation units ---

The same one-line include fix is applied to each of the ten crop_and_resize compilation units; the identical hunks are collapsed here:

@@ -19,7 +19,7 @@
 //
 #include <ops/declarable/helpers/crop_and_resize.h>
-#include "../crop_and_resize.hpp"
+#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"

 namespace sd {
 namespace ops {
--- changed file ---

@@ -0,0 +1,56 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author AbdelRauf
//
#include <ops/declarable/helpers/reductions.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
void argMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
template<typename X, typename Z>
void argMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
template<typename X, typename Z>
void argAbsMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
template<typename X, typename Z>
void argAbsMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
//////////////////////////////////////////////////////////////////////////
void argMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argMax_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES);
}
void argMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argMin_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES);
}
void argAbsMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argAbsMax_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES);
}
void argAbsMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argAbsMin_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES);
}
}
}
}
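Conceptually, each BUILD_DOUBLE_SELECTOR above expands to a runtime double dispatch over (input type, output type); a hedged illustration (the real macro generates the full LIBND4J_TYPES x INDEXING_TYPES matrix):

void argMaxDispatchSketch(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
    // one branch per (X, Z) combination; FLOAT32/INT64 shown as an example
    if (input.dataType() == sd::DataType::FLOAT32 && output.dataType() == sd::DataType::INT64)
        argMax_<float, Nd4jLong>(input, output, dimensions);
    // ... remaining type combinations elided
}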

--- changed file ---

@@ -0,0 +1,900 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author AbdelRauf
//
#include <type_traits>
#include <cmath>
#include <stdexcept>
#include <memory>
#include <execution/Threads.h>
#include <execution/ThreadPool.h>
#include <helpers/LoopsCoordsHelper.h>
#include <ops/declarable/helpers/reductions.h>
#if 1
#define LOG_CALLS(X)
#else
#define LOG_CALLS(X) nd4j_printf("___%s_________%d+\n", __PRETTY_FUNCTION__, X);
#endif
namespace sd {
namespace ops {
namespace helpers {
constexpr int threadingThreshold = 4096;
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount)
{
argCurrent = 0;
current = buffer[0];
LOG_CALLS(0)
Nd4jLong j_offset = 0;
for (Z j = 0; j < loopCount; j++) {
ReductionOp::update(current, argCurrent, buffer[j], j);
}
}
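The ReductionOp template parameter is only required to expose a static update() that folds a candidate (value, index) pair into the current winner; a hedged sketch of an arg-max functor compatible with these loops (the actual functors are defined elsewhere in this PR):

template<typename X, typename Z>
struct IndexMaxOpSketch {
    static FORCEINLINE void update(X& current, Z& argCurrent, const X& candidate, const Z& candidateIndex) {
        if (candidate > current) {
            current = candidate;
            argCurrent = candidateIndex;
        }
    }
};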
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
{
argCurrent = 0;
current = buffer[0];
LOG_CALLS(0)
Nd4jLong j_offset = 0;
for (Z j = 0; j < loopCount; j++) {
ReductionOp::update(current, argCurrent, buffer[j_offset], j);
j_offset += inner_stride;
}
}
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount)
{
//skip 1 from the beginning or the end, depending on the order
constexpr size_t updated_index = LastIndexFaster ? 0 : 1;
constexpr size_t updated_rank = constRank - 1;
sd::CoordsState<updated_rank - 1> cst;
//we skip 1
size_t offset = sd::init_coords<updated_rank, 0, LastIndexFaster>(cst, 0, bases + updated_index, strides + updated_index);
Z startIndex = 0;
argCurrent = 0;
current = buffer[offset];
LOG_CALLS(0)
for (Z i = 0; i < outerLoopCount; i++) {
const X* inner_buffer = &(buffer[offset]);
//typename std::make_signed<Z>::type iArgMax = -1;
for (Z j = 0; j < innerLoopCount; j++) {
ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex);
}
//we skip 1
offset = sd::inc_coords<updated_rank, 0, LastIndexFaster>(cst, offset);
startIndex += innerLoopCount;
}
}
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
{
//skip 1 from the beginning or the end, depending on the order
constexpr size_t updated_index = LastIndexFaster ? 0 : 1;
constexpr size_t updated_rank = constRank - 1;
sd::CoordsState<updated_rank - 1> cst;
//we skip 1
size_t offset = sd::init_coords<updated_rank, 0, LastIndexFaster>(cst, 0, bases + updated_index, strides + updated_index);
Z startIndex = 0;
argCurrent = 0;
current = buffer[offset];
LOG_CALLS(0)
for (Z i = 0; i < outerLoopCount; i++) {
const X* inner_buffer = &(buffer[offset]);
for (Z j = 0; j < innerLoopCount; j++) {
ReductionOp::update(current, argCurrent, *inner_buffer, j + startIndex);
inner_buffer += inner_stride;
}
//we already skipped 1
offset = sd::inc_coords<updated_rank, 0, LastIndexFaster>(cst, offset);
startIndex += innerLoopCount;
}
}
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount)
{
size_t offset = 0;
Nd4jLong outerLoopCount = outerLoopStop - outerLoopStart;
Nd4jLong coords[MAX_RANK] = {};
Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
if (outerLoopStart > 0) {
sd::index2coords_C(outerLoopStart, rank - 1, bases, ptr_coords);
offset = sd::offset_from_coords(strides, ptr_coords, rank);
}
Z startIndex = outerLoopStart * innerLoopCount;
argCurrent = startIndex;
current = buffer[offset];
LOG_CALLS(0)
for (Z i = 0; i < outerLoopCount; i++) {
const X* inner_buffer = &(buffer[offset]);
//typename std::make_signed<Z>::type iArgMax = -1;
for (Z j = 0; j < innerLoopCount; j++) {
//nd4j_printf("%f\n", inner_buffer[j]);
ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex);
}
offset = inc_coords<true>(bases, strides, ptr_coords, offset, rank, 1);
//if (iArgMax >= 0) argCurrent = startIndex + iArgMax;
startIndex += innerLoopCount;
}
}
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
{
size_t offset = 0;
Nd4jLong outerLoopCount = outerLoopStop - outerLoopStart;
Nd4jLong coords[MAX_RANK] = {};
Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
if (outerLoopStart > 0) {
sd::index2coords_C(outerLoopStart, rank - 1, bases, ptr_coords);
offset = sd::offset_from_coords(strides, ptr_coords, rank);
}
Z startIndex = outerLoopStart * innerLoopCount;
argCurrent = startIndex;
current = buffer[offset];
LOG_CALLS(0)
for (Z i = 0; i < outerLoopCount; i++) {
const X* inner_buffer = &(buffer[offset]);
//typename std::make_signed<Z>::type iArgMax = -1;
for (Z j = 0; j < innerLoopCount; j++) {
ReductionOp::update(current, argCurrent, inner_buffer[j * inner_stride], startIndex + j);
}
offset = inc_coords<true>(bases, strides, ptr_coords, offset, rank, 1);
//offset = inc_coords<LastIndexFaster>(bases, strides, ptr_coords, offset, rank, 1);
//if (iArgMax >= 0) argCurrent = startIndex + iArgMax;
startIndex += innerLoopCount;
}
}
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount)
{
argCurrent = 0;
current = buffer[0];
LOG_CALLS(0)
Nd4jLong loopCount4 = loopCount / 4;
Nd4jLong loopCountEnd = loopCount4 + (loopCount & 3);
const X* buffer1 = buffer + 1 * loopCount4;
const X* buffer2 = buffer1 + 1 * loopCount4;
const X* buffer3 = buffer2 + 1 * loopCount4;
X current1 = *buffer1;
X current2 = *buffer2;
X current3 = *buffer3;
Z argCurrent1 = 0;
Z argCurrent2 = 0;
Z argCurrent3 = 0;
for (Z j = 0; j < loopCount4; j++) {
ReductionOp::update(current, argCurrent, buffer[j], j);
ReductionOp::update(current1, argCurrent1, buffer1[j], j);
ReductionOp::update(current2, argCurrent2, buffer2[j], j);
ReductionOp::update(current3, argCurrent3, buffer3[j], j);
}
//tail
for (Z j = loopCount4; j < loopCountEnd; j++) {
ReductionOp::update(current3, argCurrent3, buffer3[j], j);
}
//merge: convert the block-local winners to global indices, then fold them into the overall winner
argCurrent1 += loopCount4;
argCurrent2 += 2 * loopCount4;
argCurrent3 += 3 * loopCount4;
ReductionOp::update(current, argCurrent, current1, argCurrent1);
ReductionOp::update(current, argCurrent, current2, argCurrent2);
ReductionOp::update(current, argCurrent, current3, argCurrent3);
}
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
{
argCurrent = 0;
current = buffer[0];
LOG_CALLS(0)
Nd4jLong loopCount4 = loopCount / 4;
Nd4jLong loopCountEnd = loopCount4 + (loopCount & 3);
const X* buffer1 = buffer + inner_stride * loopCount4;
const X* buffer2 = buffer1 + inner_stride * loopCount4;
const X* buffer3 = buffer2 + inner_stride * loopCount4;
X current1 = *buffer1;
X current2 = *buffer2;
X current3 = *buffer3;
Z argCurrent1 = 0;
Z argCurrent2 = 0;
Z argCurrent3 = 0;
Nd4jLong j_offset = 0;
for (Z j = 0; j < loopCount4; j++) {
ReductionOp::update(current, argCurrent, buffer[j_offset], j);
ReductionOp::update(current1, argCurrent1, buffer1[j_offset], j);
ReductionOp::update(current2, argCurrent2, buffer2[j_offset], j);
ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j);
j_offset += inner_stride;
}
//tail
for (Z j = loopCount4; j < loopCountEnd; j++) {
ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j);
j_offset += inner_stride;
}
//merge: convert the block-local winners to global indices, then fold them into the overall winner
argCurrent1 += loopCount4;
argCurrent2 += 2 * loopCount4;
argCurrent3 += 3 * loopCount4;
ReductionOp::update(current, argCurrent, current1, argCurrent1);
ReductionOp::update(current, argCurrent, current2, argCurrent2);
ReductionOp::update(current, argCurrent, current3, argCurrent3);
}
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount)
{
LOG_CALLS(0)
Z argCurrent = 0;
Z argCurrent1 = 0;
Z argCurrent2 = 0;
Z argCurrent3 = 0;
X current = buffer[0];
X current1 = buffer1[0];
X current2 = buffer2[0];
X current3 = buffer3[0];
for (Z j = 0; j < loopCount; j++) {
ReductionOp::update(current, argCurrent, buffer[j], j);
ReductionOp::update(current1, argCurrent1, buffer1[j], j);
ReductionOp::update(current2, argCurrent2, buffer2[j], j);
ReductionOp::update(current3, argCurrent3, buffer3[j], j);
}
*output = argCurrent;
*output1 = argCurrent1;
*output2 = argCurrent2;
*output3 = argCurrent3;
return;
}
template<typename X, typename Z, typename ReductionOp>
FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
{
LOG_CALLS(0)
Z argCurrent = 0;
Z argCurrent1 = 0;
Z argCurrent2 = 0;
Z argCurrent3 = 0;
X current = buffer[0];
X current1 = buffer1[0];
X current2 = buffer2[0];
X current3 = buffer3[0];
Nd4jLong j_offset = 0;
for (Z j = 0; j < loopCount; j++) {
ReductionOp::update(current, argCurrent, buffer[j_offset], j);
ReductionOp::update(current1, argCurrent1, buffer1[j_offset], j);
ReductionOp::update(current2, argCurrent2, buffer2[j_offset], j);
ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j);
j_offset += inner_stride;
}
*output = argCurrent;
*output1 = argCurrent1;
*output2 = argCurrent2;
*output3 = argCurrent3;
return;
}
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3,
Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides,
const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount)
{
LOG_CALLS(0)
//skip one dimension from the beginning or the end, depending on the order
constexpr size_t updated_index = LastIndexFaster ? 0 : 1;
constexpr size_t updated_rank = constRank - 1;
sd::CoordsState<updated_rank - 1> cst;
//we skip 1
size_t offset = sd::init_coords<updated_rank, 0, LastIndexFaster>(cst, 0, bases + updated_index, strides + updated_index);
Z startIndex = 0;
Z argCurrent = 0;
Z argCurrent1 = 0;
Z argCurrent2 = 0;
Z argCurrent3 = 0;
X current = buffer[0];
X current1 = buffer1[0];
X current2 = buffer2[0];
X current3 = buffer3[0];
for (Z i = 0; i < outerLoopCount; i++) {
const X* inner_buffer = &(buffer[offset]);
const X* inner_buffer1 = &(buffer1[offset]);
const X* inner_buffer2 = &(buffer2[offset]);
const X* inner_buffer3 = &(buffer3[offset]);
for (Z j = 0; j < innerLoopCount; j++) {
ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex);
ReductionOp::update(current1, argCurrent1, inner_buffer1[j], j + startIndex);
ReductionOp::update(current2, argCurrent2, inner_buffer2[j], j + startIndex);
ReductionOp::update(current3, argCurrent3, inner_buffer3[j], j + startIndex);
}
//we skip 1
offset = sd::inc_coords<updated_rank, 0, LastIndexFaster>(cst, offset);
startIndex += innerLoopCount;
}
*output = argCurrent;
*output1 = argCurrent1;
*output2 = argCurrent2;
*output3 = argCurrent3;
return;
}
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3,
Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides,
const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
{
LOG_CALLS(0)
//skip one dimension from the beginning or the end, depending on the order
constexpr size_t updated_index = LastIndexFaster ? 0 : 1;
constexpr size_t updated_rank = constRank - 1;
sd::CoordsState<updated_rank - 1> cst;
//we skip 1
size_t offset = sd::init_coords<updated_rank, 0, LastIndexFaster>(cst, 0, bases + updated_index, strides + updated_index);
Z startIndex = 0;
Z argCurrent = 0;
Z argCurrent1 = 0;
Z argCurrent2 = 0;
Z argCurrent3 = 0;
X current = buffer[0];
X current1 = buffer1[0];
X current2 = buffer2[0];
X current3 = buffer3[0];
for (Z i = 0; i < outerLoopCount; i++) {
const X* inner_buffer = &(buffer[offset]);
const X* inner_buffer1 = &(buffer1[offset]);
const X* inner_buffer2 = &(buffer2[offset]);
const X* inner_buffer3 = &(buffer3[offset]);
Nd4jLong inner_offset = 0;
for (Z j = 0; j < innerLoopCount; j++) {
ReductionOp::update(current, argCurrent, inner_buffer[inner_offset], j + startIndex);
ReductionOp::update(current1, argCurrent1, inner_buffer1[inner_offset], j + startIndex);
ReductionOp::update(current2, argCurrent2, inner_buffer2[inner_offset], j + startIndex);
ReductionOp::update(current3, argCurrent3, inner_buffer3[inner_offset], j + startIndex);
inner_offset += inner_stride;
}
//we skip 1
offset = sd::inc_coords<updated_rank, 0, LastIndexFaster>(cst, offset);
startIndex += innerLoopCount;
}
*output = argCurrent;
*output1 = argCurrent1;
*output2 = argCurrent2;
*output3 = argCurrent3;
return;
}
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
void argIndexCase1Scalar(const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
{
Nd4jLong inner_total;
Nd4jLong inner_last = 0;
int maxThreads = sd::Environment::getInstance()->maxMasterThreads();
if (second_rank == 1) {
inner_total = inner_bases[0];
if (inner_total < threadingThreshold) {
maxThreads = 1;
}
}
else {
inner_total = getLength<LastIndexFaster>(inner_bases, second_rank, 1, inner_last);
if (inner_total * inner_last < threadingThreshold) {
maxThreads = 1;
}
}
std::unique_ptr<X[]> maxValues(new X[maxThreads]);
std::unique_ptr<Z[]> maxIndices(new Z[maxThreads]);
X* ptrMaxValues = maxValues.get();
Z* ptrMaxIndices = maxIndices.get();
auto func = [ptrMaxValues, ptrMaxIndices, inner_last, second_rank, inner_bases, inner_strides, bufferX](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
const Nd4jLong inner_stride = LastIndexFaster ? inner_strides[second_rank - 1] : inner_strides[0];
Z argCurrent; X current;
if (second_rank == 1) {
const Nd4jLong loopTotal = stop - start;
if (inner_stride == 1) {
indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(&(bufferX[start]), current, argCurrent, loopTotal);
}
else {
indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(&(bufferX[start * inner_stride]), current, argCurrent, loopTotal, inner_stride);
}
ptrMaxIndices[thread_id] = argCurrent + start;
}
else {
if (inner_stride == 1) {
//unit stride: use the overload without the explicit stride argument
indexInnerReduction<X, Z, ReductionOp, LastIndexFaster>(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last);
}
else {
indexInnerReduction<X, Z, ReductionOp, LastIndexFaster>(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last, inner_stride);
}
ptrMaxIndices[thread_id] = argCurrent;
}
ptrMaxValues[thread_id] = current;
};
#if 0
int Count = 0;
func(0, 0, inner_total, 1);
#else
int Count = samediff::Threads::parallel_tad(func, 0, inner_total, 1, maxThreads);
#endif
Z arg = 0;
X current = ptrMaxValues[0];
for (Z i = 1; i < Count; i++) {
ReductionOp::update(current, arg, ptrMaxValues[i], i);
}
*outputZ = ptrMaxIndices[arg];
}
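// --- standalone sketch (not part of this file): argIndexCase1Scalar parallelizes the
// full reduction by giving each thread a slice, storing per-thread (value, index)
// pairs, and merging them serially afterwards. Here std::thread stands in for
// samediff::Threads; all names are illustrative.
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    std::vector<float> data(1000);
    for (size_t i = 0; i < data.size(); i++) data[i] = (float)((i * 37) % 997);
    const int numThreads = 4;
    std::vector<float> maxValues(numThreads);
    std::vector<long> maxIndices(numThreads);
    std::vector<std::thread> pool;
    const size_t chunk = data.size() / numThreads;
    for (int t = 0; t < numThreads; t++) {
        pool.emplace_back([&, t]() {
            size_t start = t * chunk;
            size_t stop = (t == numThreads - 1) ? data.size() : start + chunk;
            float best = data[start];
            long arg = (long)start;          //global index, like argCurrent + start
            for (size_t j = start + 1; j < stop; j++)
                if (data[j] > best) { best = data[j]; arg = (long)j; }
            maxValues[t] = best;
            maxIndices[t] = arg;
        });
    }
    for (auto& th : pool) th.join();
    int bestThread = 0;                      //serial merge over the partial winners
    for (int t = 1; t < numThreads; t++)
        if (maxValues[t] > maxValues[bestThread]) bestThread = t;
    printf("argmax = %ld\n", maxIndices[bestThread]);
    return 0;
}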
template<typename X, typename Z, typename ReductionOp, typename Movement, bool LastIndexFaster = true>
void argReductionInnerCases(Movement& movement, Nd4jLong loopTotal, const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
{
//the squashed inner (reduced) part keeps its last index fastest here
Nd4jLong inner_stride = inner_strides[second_rank - 1];
Nd4jLong loopTotal_K = loopTotal / 4;
Nd4jLong loopTotal_Tail = loopTotal & 3;
if (inner_stride == 1) {
if (second_rank == 1) {
LOG_CALLS(0)
Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionRank1Block4<X, Z, ReductionOp>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total);
}
if (inner_total >= 2048) {
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total);
movement.increment();
}
}
else {
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionRank1<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total);
movement.increment();
}
}
}
else {
Nd4jLong inner_last;
Nd4jLong inner_loop = getLength<true>(inner_bases, second_rank, 1, inner_last);
if (second_rank == 2) {
LOG_CALLS(1)
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 2>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
inner_loop, inner_last);
}
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionConstRank<X, Z, ReductionOp, 2>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, inner_loop, inner_last);
movement.increment();
}
}
else if (second_rank == 3) {
LOG_CALLS(2)
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 3>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
inner_loop, inner_last);
}
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionConstRank<X, Z, ReductionOp, 3>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
inner_loop, inner_last);
movement.increment();
}
}
else {
LOG_CALLS(3)
//nd4j_printf("-----%d \n", loopTotal);
for (Nd4jLong i = 0; i < loopTotal; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReduction<X, Z, ReductionOp>(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0,
inner_loop, inner_last);
movement.increment();
}
}
}
}
else {
if (second_rank == 1) {
LOG_CALLS(10)
Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionRank1Block4<X, Z, ReductionOp>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total, inner_stride);
}
if (inner_total >= 2048) {
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride);
movement.increment();
}
}
else {
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionRank1<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride);
movement.increment();
}
}
}
else {
Nd4jLong inner_last;
Nd4jLong inner_loop = getLength<true>(inner_bases, second_rank, 1, inner_last);
if (second_rank == 2) {
LOG_CALLS(11)
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 2>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
inner_loop, inner_last, inner_stride);
}
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionConstRank<X, Z, ReductionOp, 2>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
inner_loop, inner_last, inner_stride);
movement.increment();
}
}
else if (second_rank == 3) {
LOG_CALLS(12)
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
const X* buffer0 = &(bufferX[movement.First()]);
Z* output0 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer1 = &(bufferX[movement.First()]);
Z* output1 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer2 = &(bufferX[movement.First()]);
Z* output2 = &(outputZ[movement.Second()]);
movement.increment();
const X* buffer3 = &(bufferX[movement.First()]);
Z* output3 = &(outputZ[movement.Second()]);
movement.increment();
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 3>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
inner_loop, inner_last, inner_stride);
}
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReductionConstRank<X, Z, ReductionOp, 3>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
inner_loop, inner_last, inner_stride);
movement.increment();
}
}
else {
LOG_CALLS(13)
//nd4j_printf("-------%d inner loop %d inner_last %d\n", loopTotal, inner_loop,inner_last);
for (Nd4jLong i = 0; i < loopTotal; i++) {
X current;
const X* buffer0 = &(bufferX[movement.First()]);
indexInnerReduction<X, Z, ReductionOp>(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0,
inner_loop, inner_last, inner_stride);
movement.increment();
}
}
}
}
}
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
void argIndexCaseNonScalar(const int& first_rank, const int& output_rank, bool squashed, const int& second_rank,
const Nd4jLong*& outer_bases,const Nd4jLong* outer_strides,const Nd4jLong* output_strides, const Nd4jLong &output_stride,
const Nd4jLong*& inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
{
Nd4jLong total = getLength<LastIndexFaster>(outer_bases, first_rank);
//the squashed inner (reduced) part keeps its last index fastest here
Nd4jLong inner_stride = inner_strides[second_rank - 1];
//the outer stride must come from the outer (kept) part, hence first_rank
Nd4jLong outer_stride = LastIndexFaster ? outer_strides[first_rank - 1] : outer_strides[0];
auto func = [first_rank, output_rank, squashed, outer_bases, outer_strides, output_strides, output_stride, second_rank, inner_bases, inner_strides, bufferX, outputZ](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
Nd4jLong loopTotal = stop - start;
Nd4jLong stride = LastIndexFaster ? outer_strides[first_rank - 1] : outer_strides[0];
if (first_rank == 1) {
if (stride == 1) {
ZipGenericCoordsRank1Stride1 movement;
movement.init(nullptr, nullptr, nullptr, 0, start);
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
else {
ZipGenericCoordsRank1BothStrideN movement;
movement.init(nullptr, &stride, &output_stride, 0, start);
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
}
else if (squashed && first_rank <= output_rank) {
if (first_rank == 2) {
if (output_stride == 1) {
ZipGenericCoordsConstMovementSecondStride1<2, LastIndexFaster> movement;
movement.init(outer_bases, outer_strides, nullptr, first_rank, start);
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
else {
ZipGenericCoordsConstMovementSecondStrideN<2, LastIndexFaster> movement;
movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
}
else if (first_rank == 3) {
if (output_stride == 1) {
ZipGenericCoordsConstMovementSecondStride1<3, LastIndexFaster> movement;
movement.init(outer_bases, outer_strides, nullptr, first_rank, start);
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
else {
ZipGenericCoordsConstMovementSecondStrideN<3, LastIndexFaster> movement;
movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
}
else {
ZipGenericCoordsMovementSecondStrideN< LastIndexFaster> movement;
movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
}
else {
ZipGenericCoordsMovement<LastIndexFaster> movement;
movement.init(outer_bases, outer_strides, output_strides, first_rank, start);
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
};
#if 0
func(0, 0, total, 1);
#else
uint32_t numThreads = sd::Environment::getInstance()->maxMasterThreads();
Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
if (total * inner_total <= threadingThreshold) {
numThreads = 1;
}
else {
if (inner_stride > outer_stride && total <= 256) {
auto desired = total > 4 ? (total / 4) : 1;
numThreads = numThreads > desired ? desired : numThreads;
}
}
samediff::Threads::parallel_tad(func, 0, total, 1, numThreads);
#endif
}
template<typename X, typename Z, typename ReductionOp>
void argIndex_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
char input_order = input.ordering();
bool try_squash_outer = (input_order == output.ordering()) && output.ews() != 0;
const Nd4jLong* input_shapeInfo = input.shapeInfo();
const Nd4jLong* output_shapeInfo = output.shapeInfo();
const Nd4jLong rank = input_shapeInfo[0];
const Nd4jLong* input_bases = &(input_shapeInfo[1]);
const Nd4jLong* input_strides = &(input_shapeInfo[rank + 1]);
const Nd4jLong output_rank = output_shapeInfo[0];
const Nd4jLong* output_strides = &(output_shapeInfo[output_rank + 1]);
Nd4jLong new_bases[MAX_RANK];
Nd4jLong new_strides[MAX_RANK];
int first_begin, first_end, second_begin, second_end;
//rePartition into two parts based on the selection
rePartition(input_order, dimensions, rank, input_bases, input_strides, new_bases, new_strides, first_begin, first_end, second_begin, second_end, try_squash_outer, input_order == 'c');
int first_rank = first_end - first_begin; //the first rank can be 0 for scalar cases
int second_rank = second_end - second_begin;
auto bufferX = input.bufferAsT<X>();
auto outputZ = output.bufferAsT<Z>();
const Nd4jLong* outer_bases = &(new_bases[first_begin]);
const Nd4jLong* outer_strides = &(new_strides[first_begin]);
const Nd4jLong* inner_bases = &(new_bases[second_begin]);
const Nd4jLong* inner_strides = &(new_strides[second_begin]);
const Nd4jLong output_stride = output.ordering() == 'c' ? output_strides[output_rank-1]:output_strides[0];
if (input_order == 'c') {
if (first_rank == 0) {
argIndexCase1Scalar<X, Z, ReductionOp>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
else {
argIndexCaseNonScalar<X, Z, ReductionOp>(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides,
output_stride,inner_bases, inner_strides, bufferX, outputZ);
}
}
else {
if (first_rank == 0) {
LOG_CALLS(0);
if (second_rank == 1) {
argIndexCase1Scalar<X, Z, ReductionOp, false>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
else {
argIndexCase1Scalar<X, Z, ReductionOp, true>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
}
}
else {
LOG_CALLS(1);
argIndexCaseNonScalar<X, Z, ReductionOp,false>(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides,
output_stride, inner_bases, inner_strides, bufferX, outputZ);
}
}
}
template <typename X, typename Z>
struct IndexMax {
static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
if (candidate > current) {
current = candidate;
currentIndex = candidateIndex;
}
}
};
template <typename X, typename Z>
struct IndexMin {
static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
if (candidate < current) {
current = candidate;
currentIndex = candidateIndex;
}
}
};
template <typename X, typename Z>
struct IndexAbsMax {
static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
auto absCandidate = sd::math::nd4j_abs<X>(candidate);
if (absCandidate > current) {
current = absCandidate;
currentIndex = candidateIndex;
}
}
};
template <typename X, typename Z>
struct IndexAbsMin {
static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
auto absCandidate = sd::math::nd4j_abs<X>(candidate);
if (absCandidate < current) {
current = absCandidate;
currentIndex = candidateIndex;
}
}
};
//////////////////////////////////////////////////////////////////////////
template<typename X, typename Z>
void argMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
return argIndex_<X, Z, IndexMax<X, Z>>(input, output, dimensions);
}
template<typename X, typename Z>
void argMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
return argIndex_<X, Z, IndexMin<X, Z>>(input, output, dimensions);
}
template<typename X, typename Z>
void argAbsMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
return argIndex_<X, Z, IndexAbsMax<X, Z>>(input, output, dimensions);
}
template<typename X, typename Z>
void argAbsMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
return argIndex_<X, Z, IndexAbsMin<X, Z>>(input, output, dimensions);
}
}
}
}
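// --- standalone sketch (not part of this file): the four functors above are the only
// thing that differs between argmax/argmin/argamax/argamin; the traversal machinery is
// shared. A minimal illustration of the ReductionOp::update contract, with simplified
// types and illustrative names (not libnd4j API):
#include <cstdio>

template <typename X, typename Z>
struct IndexMaxSketch {
    static inline void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
        if (candidate > current) {
            current = candidate;
            currentIndex = candidateIndex;
        }
    }
};

template <typename X, typename Z, typename ReductionOp>
Z indexReduceFlat(const X* buffer, Z length) {
    X current = buffer[0];
    Z argCurrent = 0;
    for (Z j = 1; j < length; j++)           //same update contract argIndex_ relies on
        ReductionOp::update(current, argCurrent, buffer[j], j);
    return argCurrent;
}

int main() {
    float data[] = {0.1f, 0.5f, 0.7f, 0.2f};
    printf("%d\n", (int)indexReduceFlat<float, int, IndexMaxSketch<float, int>>(data, 4)); //prints 2
    return 0;
}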

View File

@ -0,0 +1,106 @@
/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/reductions.h>
#include <legacy/NativeOpExecutioner.h>
#include <helpers/ConstantTadHelper.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
void argMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
NDArray::prepareSpecialUse({&output}, {&input});
if (output.isScalar()) {
NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexMax, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo());
}
else {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexMax,
input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(),
nullptr,
output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(),
(int*) nullptr, dimensions.size(),
tadPack.specialShapeInfo(), tadPack.specialOffsets());
}
NDArray::registerSpecialUse({ &output }, { &input });
}
void argMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
NDArray::prepareSpecialUse({ &output }, { &input });
if (output.isScalar()) {
NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexMin, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo());
}
else {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexMin,
input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(),
nullptr,
output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(),
(int*) nullptr, dimensions.size(),
tadPack.specialShapeInfo(), tadPack.specialOffsets());
}
NDArray::registerSpecialUse({ &output }, { &input });
}
void argAbsMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
NDArray::prepareSpecialUse({ &output }, { &input });
if (output.isScalar()) {
NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMax, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo());
}
else {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMax,
input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(),
nullptr,
output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(),
(int*) nullptr, dimensions.size(),
tadPack.specialShapeInfo(), tadPack.specialOffsets());
}
NDArray::registerSpecialUse({ &output }, { &input });
}
void argAbsMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
NDArray::prepareSpecialUse({ &output }, { &input });
if (output.isScalar()) {
NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMin, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo());
}
else {
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMin,
input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(),
nullptr,
output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(),
(int *) nullptr, dimensions.size(),
tadPack.specialShapeInfo(), tadPack.specialOffsets());
}
NDArray::registerSpecialUse({&output}, {&input});
}
}
}
}
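// --- usage sketch (not part of this file): these helpers are what the declarable
// argmax/argmin ops dispatch to. A hedged end-to-end example, mirroring the call
// pattern of the Playground benchmark later in this change; shapes and values are
// illustrative, and the snippet assumes the same test-harness includes as that file.
sd::ops::argmax op;
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, 5.f, 2.f, 7.f, 0.f, 3.f});
auto dim = NDArrayFactory::create<int>(std::vector<int>{1});
auto result = op.evaluate({&x, &dim}, {}, {});
auto z = result.at(0); //Nd4jLong indices along dimension 1 -> [1, 0]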

View File

@ -0,0 +1,41 @@
/*******************************************************************************
* Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author AbdelRauf (rauf@konduit.ai)
//
#ifndef LIBND4J_HELPERS_REDUCTIONS_H
#define LIBND4J_HELPERS_REDUCTIONS_H
#include <system/op_boilerplate.h>
#include <math/templatemath.h>
#include <array/NDArray.h>
namespace sd {
namespace ops {
namespace helpers {
void argMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
void argAbsMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
void argMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
void argAbsMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
}
}
}
#endif
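// --- usage sketch (not part of this file): a hedged example of calling the declared
// helpers directly; NDArrayFactory and the Nd4jLong output type are as used in the
// tests elsewhere in this change, and the expected result follows from row-wise
// abs-argmax.
auto x = NDArrayFactory::create<float>('c', {2, 3}, {1.f, -5.f, 2.f, 7.f, 0.f, -9.f});
auto z = NDArrayFactory::create<Nd4jLong>('c', {2});
sd::ops::helpers::argAbsMax(x, z, {1}); //abs-argmax per row -> z = [1, 2]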

View File

@ -40,6 +40,19 @@ public:
}
};
TEST_F(DeclarableOpsTests19, test_argmax_maxint_vector_1) {
auto x = NDArrayFactory::create<float>('c', {3}, {0.1f, 0.5f, 0.7f});
auto z = NDArrayFactory::create<Nd4jLong>(0);
auto e = NDArrayFactory::create<Nd4jLong>(2);
sd::ops::argmax op;
auto status = op.execute({&x}, {&z}, {DataTypeUtils::max<int>()});
ASSERT_EQ(Status::OK(), status);
ASSERT_EQ(e, z);
}
TEST_F(DeclarableOpsTests19, test_threshold_encode_1) {
auto x = NDArrayFactory::create<double>('c', {3}, {1.5, 2.5, -3.5});
auto exp_encoded = NDArrayFactory::create<int>('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3});
@ -276,6 +289,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
}
TEST_F(DeclarableOpsTests19, test_matmul_ccc) {
auto x = NDArrayFactory::create<float>('c', {10, 10});
auto y = NDArrayFactory::create<float>('c', {10, 10});

View File

@ -43,9 +43,12 @@
#include <array>
#include <performance/benchmarking/FullBenchmarkSuit.h>
#include <performance/benchmarking/LightBenchmarkSuit.h>
#include <random>
#include <ops/declarable/helpers/legacy_helpers.h>
#include <ops/declarable/helpers/addBias.h>
#include <ops/declarable/helpers/axis.h>
#include <ops/declarable/helpers/reductions.h>
#include <helpers/LoopsCoordsHelper.h>
using namespace sd;
using namespace sd::graph;
@ -275,6 +278,256 @@ TEST_F(PlaygroundTests, test_one_off_ops_1) {
op.execute({&x, &y}, {&z});
}
#if defined(INDEX_REDUCTIONS_BENCH_TESTS)
//temporary: benchmark the new implementation against the original one
void original_argmax(const NDArray& input, std::vector<int>& axis, NDArray& output) {
sd::ops::helpers::adjustAxis(input.rankOf(), axis);
input.applyIndexReduce(sd::indexreduce::IndexMax, output, axis);
}
template<typename T>
void fill_random(sd::NDArray& arr) {
Nd4jLong coords[MAX_RANK] = {};
std::random_device rd;
std::mt19937 gen(rd());
//uniform values in [-10, 22.9); assumes a floating-point T
std::uniform_real_distribution<T> dis((T)-10.0, (T)22.9);
T* x = arr.bufferAsT<T>();
Nd4jLong* shapeInfo = arr.getShapeInfo();
Nd4jLong* strides = arr.stridesOf();
Nd4jLong rank = shapeInfo[0];
Nd4jLong* bases = &(shapeInfo[1]);
size_t t = 1;
for (Nd4jLong i = 0; i < rank; i++) {
t *= bases[i];
}
size_t offset = 0;
if (arr.ordering() == 'c') {
for (size_t i = 0; i < t; i++) {
x[offset] = dis(gen);
offset = sd::inc_coords(bases, strides, coords, offset, rank);
}
}
else {
for (size_t i = 0; i < t; i++) {
x[offset] = dis(gen);
offset = sd::inc_coords<false>(bases, strides, coords, offset, rank);
}
}
}
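// --- standalone sketch (not part of this file): fill_random above walks the array with
// sd::inc_coords, an offset-based "odometer" from LoopsCoordsHelper. For C order the
// increment advances the fastest (last) index and carries into slower dimensions;
// incCoordsC below is an illustrative stand-in, not the real helper.
#include <cstdio>

size_t incCoordsC(const long* bases, const long* strides, long* coords, size_t offset, long rank) {
    for (long i = rank - 1; i >= 0; i--) {
        if (++coords[i] < bases[i]) return offset + strides[i];
        offset -= (bases[i] - 1) * strides[i]; //wrap this dimension back to 0
        coords[i] = 0;
    }
    return 0; //full wrap-around
}

int main() {
    long bases[] = {2, 3}, strides[] = {3, 1}, coords[] = {0, 0};
    size_t offset = 0;
    for (int i = 0; i < 6; i++) {
        printf("%zu ", offset); //prints 0 1 2 3 4 5 for a contiguous 2x3 array
        offset = incCoordsC(bases, strides, coords, offset, 2);
    }
    printf("\n");
    return 0;
}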
void testLegacy(bool random) {
#if 0
int bases[] = { 3, 2, 4, 5, 7 };
constexpr int Loop = 1;
#else
int bases[] = { 8, 32, 64, 32, 64 };
constexpr int Loop = 10;
#endif
constexpr int N = 5;
auto x = NDArrayFactory::create<float>('c', { bases[0], bases[1], bases[2], bases[3], bases[4] });
if (!random) {
x.linspace(1);
}
else{
fill_random<float>(x);
}
#define COMBINATIONS 1
#if COMBINATIONS
//https://www.rosettacode.org/wiki/Combinations#C.2B.2B
for (int k = N; k >= 1; k--) {
std::string bitmask(k, 1); // K leading 1's
bitmask.resize(N, 0); // N-K trailing 0's
do {
std::vector<int> dimension;
std::vector<Nd4jLong> output_bases;
for (int i = 0; i < N; ++i) // [0..N-1] integers
{
if (bitmask[i]) dimension.push_back(i);
else {
output_bases.push_back(bases[i]);
}
}
#else
std::vector<int> dimension = { 0,1,2,3 };
int k = 4;
//the non-combinations path still needs the output shape of the kept dimensions
std::vector<Nd4jLong> output_bases;
for (int i = k; i < N; i++) output_bases.push_back(bases[i]);
#endif
auto dim = NDArrayFactory::create<int>(dimension);
#if 1
nd4j_printf("C(N:%d K:%d) \n", N, k);
dim.printIndexedBuffer("Dimension");
for (int xind : dimension) {
nd4j_printf(" %d ,", bases[xind]);
}
nd4j_printf("%s", "\n");
#endif
std::vector<Nd4jLong> values;
sd::ResultSet result;
for (int e = 0; e < Loop; e++) {
auto timeStart = std::chrono::system_clock::now();
NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create<Nd4jLong>('c', output_bases) : NDArrayFactory::create<Nd4jLong>(0);
original_argmax(x, dimension, exp);
auto timeEnd = std::chrono::system_clock::now();
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
values.emplace_back(outerTime);
}
std::sort(values.begin(), values.end());
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
#if COMBINATIONS
} while (std::prev_permutation(bitmask.begin(), bitmask.end()));
}
#endif
}
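// --- standalone sketch (not part of this file): both benchmarks enumerate every axis
// combination with the bitmask/std::prev_permutation idiom from the Rosetta Code page
// referenced above; prev_permutation steps the descending-sorted mask through all
// C(N,k) placements of the k ones.
#include <algorithm>
#include <cstdio>
#include <string>
#include <vector>

int main() {
    constexpr int N = 4;
    for (int k = N; k >= 1; k--) {
        std::string bitmask(k, 1); //k leading 1's
        bitmask.resize(N, 0);      //N-k trailing 0's
        do {
            std::vector<int> dimension;
            for (int i = 0; i < N; ++i)
                if (bitmask[i]) dimension.push_back(i); //the chosen axes
            for (int d : dimension) printf("%d ", d);
            printf("\n");
        } while (std::prev_permutation(bitmask.begin(), bitmask.end()));
    }
    return 0;
}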
#define DEBUG 1
void testNewReduction(bool random, bool checkCorrectness = false , char order ='c') {
std::vector<Nd4jLong> arr_dimensions;
#if defined(DEBUG)
int bases[] = { 3, 2, 3, 3, 5, 4, 7, 4, 7, 7 };
constexpr int Loop = 1;
constexpr int N = 10;
#else
int bases[] = { 8, 32, 64, 32, 64 };
constexpr int Loop = 10;
constexpr int N = 5;
#endif
for (int i = 0; i < N; i++) {
arr_dimensions.push_back(bases[i]);
}
auto x = NDArrayFactory::create<float>(order,arr_dimensions);
if (!random) {
x.linspace(1);
}
else {
fill_random<float>(x);
}
#define COMBINATIONS 1
#if COMBINATIONS
//https://www.rosettacode.org/wiki/Combinations#C.2B.2B
for (int k = N; k >= 1; k--) {
std::string bitmask(k, 1); // K leading 1's
bitmask.resize(N, 0); // N-K trailing 0's
do {
std::vector<int> dimension;
std::vector<Nd4jLong> output_bases;
for (int i = 0; i < N; ++i) // [0..N-1] integers
{
if (bitmask[i]) dimension.push_back(i);
else {
output_bases.push_back(bases[i]);
}
}
#else
std::vector<int> dimension = { 0,1,2,3 };
int k = 4;
//the non-combinations path still needs the output shape of the kept dimensions
std::vector<Nd4jLong> output_bases;
for (int i = k; i < N; i++) output_bases.push_back(bases[i]);
#endif
auto dim = NDArrayFactory::create<int>(dimension);
#if 1
nd4j_printf("C(N:%d K:%d) \n", N, k);
dim.printIndexedBuffer("Dimension");
for (int xind : dimension) {
nd4j_printf(" %d ,", bases[xind]);
}
nd4j_printf("%s", "\n");
#endif
sd::ops::argmax op;
std::vector<Nd4jLong> values;
sd::ResultSet result;
for (int e = 0; e < Loop; e++) {
auto timeStart = std::chrono::system_clock::now();
result = op.evaluate({ &x, &dim }, {}, {});
auto timeEnd = std::chrono::system_clock::now();
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
values.emplace_back(outerTime);
}
auto z = result.at(0);
if (checkCorrectness) {
//check for the correctness
NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create<Nd4jLong>('c', output_bases) : NDArrayFactory::create<Nd4jLong>(0);
original_argmax(x, dimension, exp);
#if 0// defined(DEBUG)
x.printIndexedBuffer("X");
exp.printIndexedBuffer("Expected");
z->printIndexedBuffer("Z");
#endif
ASSERT_TRUE(exp.isSameShape(z));
ASSERT_TRUE(exp.equalsTo(z));
}
std::sort(values.begin(), values.end());
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
#if COMBINATIONS
} while (std::prev_permutation(bitmask.begin(), bitmask.end()));
}
#endif
}
constexpr bool test_corr = true;
#if !defined(DEBUG)
TEST_F(PlaygroundTests, ArgMaxPerfLinspace) {
testNewReduction(false, test_corr);
}
#endif
TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
testNewReduction(true, test_corr);
}
TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) {
testNewReduction(true, test_corr, 'f');
}
#if !defined(DEBUG)
TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) {
testLegacy(false);
}
TEST_F(PlaygroundTests, ArgMaxPerfLegacyRandom) {
testLegacy(true);
}
#endif
#endif
/*

View File

@ -106,7 +106,7 @@ public class SDBaseOps {
public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) {
SDValidation.validateNumerical("argmax", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, keepDims, dimensions).outputVariable();
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable();
}
/**
@ -130,7 +130,7 @@ public class SDBaseOps {
public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) {
SDValidation.validateNumerical("argmax", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, keepDims, dimensions).outputVariable();
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
@ -153,7 +153,7 @@ public class SDBaseOps {
public SDVariable argmax(SDVariable in, int... dimensions) {
SDValidation.validateNumerical("argmax", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, false, dimensions).outputVariable();
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable();
}
/**
@ -176,7 +176,7 @@ public class SDBaseOps {
public SDVariable argmax(String name, SDVariable in, int... dimensions) {
SDValidation.validateNumerical("argmax", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, false, dimensions).outputVariable();
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
@ -203,7 +203,7 @@ public class SDBaseOps {
public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) {
SDValidation.validateNumerical("argmin", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, keepDims, dimensions).outputVariable();
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable();
}
/**
@ -230,7 +230,7 @@ public class SDBaseOps {
public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... dimensions) {
SDValidation.validateNumerical("argmin", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, keepDims, dimensions).outputVariable();
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
@ -256,7 +256,7 @@ public class SDBaseOps {
public SDVariable argmin(SDVariable in, int... dimensions) {
SDValidation.validateNumerical("argmin", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, false, dimensions).outputVariable();
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable();
}
/**
@ -282,7 +282,7 @@ public class SDBaseOps {
public SDVariable argmin(String name, SDVariable in, int... dimensions) {
SDValidation.validateNumerical("argmin", "in", in);
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, false, dimensions).outputVariable();
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}

View File

@ -1875,7 +1875,7 @@ public class SDMath extends SDOps {
public SDVariable iamax(SDVariable in, int... dimensions) {
SDValidation.validateNumerical("iamax", "in", in);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, false, dimensions).outputVariable();
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable();
}
/**
@ -1890,7 +1890,7 @@ public class SDMath extends SDOps {
public SDVariable iamax(String name, SDVariable in, int... dimensions) {
SDValidation.validateNumerical("iamax", "in", in);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, false, dimensions).outputVariable();
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
@ -1906,7 +1906,7 @@ public class SDMath extends SDOps {
public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) {
SDValidation.validateNumerical("iamax", "in", in);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, keepDims, dimensions).outputVariable();
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable();
}
/**
@ -1922,7 +1922,7 @@ public class SDMath extends SDOps {
public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) {
SDValidation.validateNumerical("iamax", "in", in);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, keepDims, dimensions).outputVariable();
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
@ -1937,7 +1937,7 @@ public class SDMath extends SDOps {
public SDVariable iamin(SDVariable in, int... dimensions) {
SDValidation.validateNumerical("iamin", "in", in);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, false, dimensions).outputVariable();
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable();
}
/**
@ -1952,7 +1952,7 @@ public class SDMath extends SDOps {
public SDVariable iamin(String name, SDVariable in, int... dimensions) {
SDValidation.validateNumerical("iamin", "in", in);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, false, dimensions).outputVariable();
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}
@ -1968,7 +1968,7 @@ public class SDMath extends SDOps {
public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) {
SDValidation.validateNumerical("iamin", "in", in);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, keepDims, dimensions).outputVariable();
+    return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable();
}
/**
@ -1984,7 +1984,7 @@ public class SDMath extends SDOps {
public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) {
SDValidation.validateNumerical("iamin", "in", in);
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, keepDims, dimensions).outputVariable();
+    SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable();
return sd.updateVariableNameAndReference(out, name);
}

View File

@ -682,14 +682,6 @@ public class LegacyOpMapper {
public static Class<?> indexReduceClass(int opNum){
switch (opNum){
case 0:
return IMax.class;
case 1:
return IMin.class;
case 2:
return IAMax.class;
case 3:
return IAMin.class;
case 4:
return FirstIndex.class;
case 5:

View File

@ -1055,10 +1055,6 @@ public class OpValidation {
IsNumericTensor.class,
//Exclude index accumulations (index out, not real-valued)
FirstIndex.class,
IAMax.class,
IAMin.class,
IMax.class,
IMin.class,
LastIndex.class,
ArgMax.class,
ArgMin.class,

View File

@ -105,13 +105,11 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class,
org.nd4j.linalg.api.ops.impl.image.ResizeArea.class,
org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class,
org.nd4j.linalg.api.ops.impl.indexaccum.IAMax.class,
org.nd4j.linalg.api.ops.impl.indexaccum.IAMin.class,
org.nd4j.linalg.api.ops.impl.indexaccum.IMax.class,
org.nd4j.linalg.api.ops.impl.indexaccum.IMin.class,
org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex.class,
org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax.class,
org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin.class,
org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax.class,
org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmin.class,
org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D.class,
org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D.class,

View File

@ -1,78 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import java.util.Collections;
import java.util.List;
/**
* Calculate the index of the max absolute value over a vector
*
* @author Adam Gibson
*/
@Data
public class IAMax extends BaseIndexAccumulation {
public IAMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
super(sameDiff, i_v, keepDims, dimensions);
}
public IAMax() {}
public IAMax(INDArray x, int... dimensions) {
this(x, false, dimensions);
}
public IAMax(INDArray x, boolean keepDims, int... dimensions) {
this(x, null, dimensions);
this.keepDims = keepDims;
}
public IAMax(INDArray x, INDArray z, int... dimensions) {
super(x, z, dimensions);
}
@Override
public int opNum() {
return 2;
}
@Override
public String opName() {
return "iamax";
}
@Override
public String onnxName() {
return "AbsArgMax";
}
@Override
public String tensorflowName() {
return "absargmax";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> grad){
return Collections.singletonList(sameDiff.zerosLike(arg()));
}
}

View File

@ -1,80 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import java.util.Collections;
import java.util.List;
/**
 * Calculate the index of the min absolute value over a vector
*
* @author Adam Gibson
*/
@Data
public class IAMin extends BaseIndexAccumulation {
public IAMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
super(sameDiff, i_v, keepDims, dimensions);
}
public IAMin() {}
public IAMin(INDArray x, int... dimensions) {
super(x, dimensions);
}
public IAMin(INDArray in, boolean keepDims, int... dimensions){
super(in, null, dimensions);
this.keepDims = keepDims;
}
public IAMin(INDArray x, INDArray z, int... dimensions) {
super(x, z, dimensions);
}
@Override
public int opNum() {
return 3;
}
@Override
public String opName() {
return "iamin";
}
@Override
public String onnxName() {
return "AbsArgMin";
}
@Override
public String tensorflowName() {
return "absargmin";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> grad){
return Collections.singletonList(sameDiff.zerosLike(arg()));
}
}

View File

@@ -1,87 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import java.util.Collections;
import java.util.List;
/**
 * Calculate the index of the max value over a vector
*
* @author Alex Black
*/
@Data
public class IMax extends BaseIndexAccumulation {
public IMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
super(sameDiff, i_v, keepDims, dimensions);
}
public IMax() {
}
public IMax(INDArray x, INDArray z, int... dimensions) {
super(x, z, dimensions);
}
public IMax(INDArray x, int... dimensions) {
super(x, null, dimensions);
}
public IMax(INDArray x, boolean keepDims, int... dimensions) {
super(x, null, dimensions);
this.keepDims = keepDims;
}
@Override
public int opNum() {
return 0;
}
@Override
public String opName() {
return "imax";
}
@Override
public String onnxName() {
return "arg_max";
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public Type opType() {
return Type.INDEXREDUCE;
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
//Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input
return Collections.singletonList(sameDiff.zerosLike(arg()));
}
}

View File

@@ -1,83 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.indexaccum;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
import java.util.Collections;
import java.util.List;
/**
 * Calculate the index of the min value over a vector
*
* @author Alex Black
*/
@Data
public class IMin extends BaseIndexAccumulation {
public IMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
super(sameDiff, i_v, keepDims, dimensions);
}
public IMin() {
}
public IMin(INDArray x, int... dimensions) {
super(x, dimensions);
}
public IMin(INDArray x, boolean keepDims, int... dimensions) {
super(x, keepDims, dimensions);
}
public IMin(INDArray x, INDArray z, int... dimensions) {
super(x, z, dimensions);
}
@Override
public int opNum() {
return 1;
}
@Override
public String opName() {
return "imin";
}
@Override
public String onnxName() {
return "ArgMin";
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
//Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input
return Collections.singletonList(sameDiff.zerosLike(arg()));
}
}
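The four legacy index accumulations removed above are superseded by the DynamicCustomOp variants added below. A minimal before/after sketch of the calling-convention change (input chosen for illustration; an nd4j backend on the classpath is assumed):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.factory.Nd4j;

public class IndexReductionMigration {
    public static void main(String[] args) {
        INDArray arr = Nd4j.linspace(1, 10, 10);
        // before (legacy, removed above):
        //   Nd4j.getExecutioner().execAndReturn(new IMax(arr)).getFinalResult().intValue();
        // after: exec on a DynamicCustomOp returns its output arrays
        int idx = Nd4j.exec(new ArgMax(arr))[0].getInt(0);
        System.out.println(idx); // 9
    }
}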

View File

@@ -0,0 +1,111 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Data
public class ArgAmax extends DynamicCustomOp {
protected boolean keepDims = false;
private int[] dimensions;
protected DataType outputType = DataType.INT64;
public ArgAmax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
super(sameDiff, i_v);
this.keepDims = keepDims;
this.dimensions = dimensions;
if (dimensions != null && dimensions.length > 0)
addIArgument(dimensions);
addBArgument(keepDims);
addDArgument(outputType);
}
public ArgAmax() {
}
public ArgAmax(INDArray x, INDArray z, boolean keepDims, int... dimensions) {
super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]);
this.keepDims = keepDims;
this.dimensions = dimensions;
if (dimensions != null && dimensions.length > 0)
addIArgument(dimensions);
addBArgument(keepDims);
addDArgument(outputType);
}
public ArgAmax(INDArray x, INDArray z, int... dimensions) {
this(x, z, false, dimensions);
}
public ArgAmax(INDArray x, int... dimensions) {
this(x, null, dimensions);
}
public ArgAmax(INDArray x, boolean keepDims, int... dimensions) {
this(x, null, keepDims, dimensions);
}
@Override
public String opName() {
return "argamax";
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
if(attributesForNode.containsKey("output_type")) {
outputType = TFGraphMapper.convertType(attributesForNode.get("output_type").getType());
} else {
outputType = DataType.LONG;
}
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
"Expected 1 or 2 input datatype to argamax, got %s", inputDataTypes); //2nd input: axis
return Collections.singletonList(outputType == null ? DataType.LONG : outputType);
}
}
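A minimal usage sketch for the new op; the input values match the IAMax/ArgAmax test updated later in this commit, and an nd4j backend on the classpath is assumed:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax;
import org.nd4j.linalg.factory.Nd4j;

public class ArgAmaxExample {
    public static void main(String[] args) {
        INDArray x = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01});
        // index of the largest absolute value: |-0.26| at position 1
        INDArray idx = Nd4j.exec(new ArgAmax(x))[0];
        System.out.println(idx.getInt(0)); // 1
    }
}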

View File

@@ -0,0 +1,111 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@Data
public class ArgAmin extends DynamicCustomOp {
protected boolean keepDims = false;
private int[] dimensions;
protected DataType outputType = DataType.INT64;
public ArgAmin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
super(sameDiff, i_v);
this.keepDims = keepDims;
this.dimensions = dimensions;
if (dimensions != null && dimensions.length > 0)
addIArgument(dimensions);
addBArgument(keepDims);
addDArgument(outputType);
}
public ArgAmin() {
}
public ArgAmin(INDArray x, INDArray z, boolean keepDims, int... dimensions) {
super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]);
this.keepDims = keepDims;
this.dimensions = dimensions;
if (dimensions != null && dimensions.length > 0)
addIArgument(dimensions);
addBArgument(keepDims);
addDArgument(outputType);
}
public ArgAmin(INDArray x, INDArray z, int... dimensions) {
this(x, z, false, dimensions);
}
public ArgAmin(INDArray x, int... dimensions) {
this(x, null, dimensions);
}
public ArgAmin(INDArray x, boolean keepDims, int... dimensions) {
this(x, null, keepDims, dimensions);
}
@Override
public String opName() {
return "argamin";
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
if(attributesForNode.containsKey("output_type")) {
outputType = TFGraphMapper.convertType(attributesForNode.get("output_type").getType());
} else {
outputType = DataType.LONG;
}
}
@Override
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
"Expected 1 or 2 input datatype to argamin, got %s", inputDataTypes); //2nd input: axis
return Collections.singletonList(outputType == null ? DataType.LONG : outputType);
}
}
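ArgAmin mirrors ArgAmax; a short sketch of the keepDims flag and the axis argument (input values are illustrative only):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmin;
import org.nd4j.linalg.factory.Nd4j;

public class ArgAminExample {
    public static void main(String[] args) {
        INDArray m = Nd4j.create(new double[][] {{3, -1}, {-4, 2}});
        // per-row index of the smallest absolute value, keeping the reduced axis
        INDArray idx = Nd4j.exec(new ArgAmin(m, true, 1))[0];
        System.out.println(idx); // shape [2, 1], values [[1], [1]]
    }
}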

View File

@@ -17,10 +17,12 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
+import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@@ -32,8 +34,53 @@ import java.util.Map;
@Data
public class ArgMax extends DynamicCustomOp {
+    protected boolean keepDims = false;
+    private int[] dimensions;
-    protected DataType outputType;
+    protected DataType outputType = DataType.INT64;
+    public ArgMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
+        super(sameDiff, i_v);
+        this.keepDims = keepDims;
+        this.dimensions = dimensions;
+        if (dimensions != null && dimensions.length > 0)
+            addIArgument(dimensions);
+        addBArgument(keepDims);
+        addDArgument(outputType);
+    }
+    public ArgMax() {
+    }
+    public ArgMax(INDArray x, INDArray z, boolean keepDims, int... dimensions) {
+        super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]);
+        this.keepDims = keepDims;
+        this.dimensions = dimensions;
+        if (dimensions != null && dimensions.length > 0)
+            addIArgument(dimensions);
+        addBArgument(keepDims);
+        addDArgument(outputType);
+    }
+    public ArgMax(INDArray x, INDArray z, int... dimensions) {
+        this(x, z, false, dimensions);
+    }
+    public ArgMax(INDArray x, int... dimensions) {
+        this(x, null, dimensions);
+    }
+    public ArgMax(INDArray x, boolean keepDims, int... dimensions) {
+        this(x, null, keepDims, dimensions);
+    }
@Override
public String opName() {

View File

@@ -17,10 +17,12 @@
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
import lombok.Data;
+import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.common.base.Preconditions;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
@@ -37,8 +39,53 @@ import java.util.Map;
*/
@Data
public class ArgMin extends DynamicCustomOp {
+    protected boolean keepDims = false;
+    private int[] dimensions;
-    protected DataType outputType = DataType.LONG;
+    protected DataType outputType = DataType.INT64;
+    public ArgMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
+        super(sameDiff, i_v);
+        this.keepDims = keepDims;
+        this.dimensions = dimensions;
+        if (dimensions != null && dimensions.length > 0)
+            addIArgument(dimensions);
+        addBArgument(keepDims);
+        addDArgument(outputType);
+    }
+    public ArgMin() {
+    }
+    public ArgMin(INDArray x, INDArray z, boolean keepDims, int... dimensions) {
+        super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]);
+        this.keepDims = keepDims;
+        this.dimensions = dimensions;
+        if (dimensions != null && dimensions.length > 0)
+            addIArgument(dimensions);
+        addBArgument(keepDims);
+        addDArgument(outputType);
+    }
+    public ArgMin(INDArray x, INDArray z, int... dimensions) {
+        this(x, z, false, dimensions);
+    }
+    public ArgMin(INDArray x, int... dimensions) {
+        this(x, null, dimensions);
+    }
+    public ArgMin(INDArray x, boolean keepDims, int... dimensions) {
+        this(x, null, keepDims, dimensions);
+    }
@Override
public String opName() {

View File

@@ -17,6 +17,8 @@
package org.nd4j.linalg.factory;
import lombok.extern.slf4j.Slf4j;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.factory.ops.*;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.primitives.Longs;
@@ -50,8 +52,6 @@ import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.OpContext;
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
@@ -627,16 +627,16 @@ public class Nd4j {
 * @return array of maximum values.
 */
public static INDArray argMax(INDArray arr, @NonNull int... dimension) {
-    IMax imax = new IMax(arr, dimension);
-    return Nd4j.getExecutioner().exec(imax);
+    val imax = new ArgMax(arr, dimension);
+    return Nd4j.getExecutioner().exec(imax)[0];
}
/**
 * See {@link #argMax(INDArray, int...)} but return minimum values.
 */
public static INDArray argMin(INDArray arr, @NonNull int... dimension) {
-    IMin imin = new IMin(arr, dimension);
-    return Nd4j.getExecutioner().exec(imin);
+    val imin = new ArgMin(arr, dimension);
+    return Nd4j.getExecutioner().exec(imin)[0];
}
/**
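The facade signatures stay the same for callers; what changes is the backing op, and that DynamicCustomOp execution returns an array of outputs, hence the trailing [0]. A small sketch (illustrative input):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.factory.Nd4j;

public class ArgMaxFacade {
    public static void main(String[] args) {
        INDArray arr = Nd4j.create(new double[] {1, 5, 3});
        INDArray viaFacade = Nd4j.argMax(arr);                           // [1]
        INDArray viaOp = Nd4j.getExecutioner().exec(new ArgMax(arr))[0]; // same result
        System.out.println(viaFacade + " " + viaOp);
    }
}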

View File

@@ -75,7 +75,7 @@ public class NDBase {
public INDArray argmax(INDArray in, boolean keepDims, int... dimensions) {
    NDValidation.validateNumerical("argmax", "in", in);
    Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(in, keepDims, dimensions));
+   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0];
}
/**
@@ -97,7 +97,7 @@
public INDArray argmax(INDArray in, int... dimensions) {
    NDValidation.validateNumerical("argmax", "in", in);
    Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(in, false, dimensions));
+   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0];
}
/**
@@ -123,7 +123,7 @@
public INDArray argmin(INDArray in, boolean keepDims, int... dimensions) {
    NDValidation.validateNumerical("argmin", "in", in);
    Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, keepDims, dimensions));
+   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0];
}
/**
@@ -148,7 +148,7 @@
public INDArray argmin(INDArray in, int... dimensions) {
    NDValidation.validateNumerical("argmin", "in", in);
    Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
-   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions));
+   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0];
}
/**
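A minimal sketch of the keepDims semantics these wrappers forward to the op (illustrative values; the custom op is used directly here):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.factory.Nd4j;

public class ArgMaxKeepDims {
    public static void main(String[] args) {
        INDArray in = Nd4j.create(new double[][] {{1, 9}, {8, 2}});
        // argmax along dimension 0, i.e. the row index of each column's max
        INDArray flat = Nd4j.exec(new ArgMax(in, false, 0))[0]; // shape [2], values [1, 0]
        INDArray kept = Nd4j.exec(new ArgMax(in, true, 0))[0];  // shape [1, 2], same values
        System.out.println(flat + " " + kept);
    }
}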

View File

@@ -896,7 +896,7 @@ public class NDMath {
public INDArray iamax(INDArray in, int... dimensions) {
    NDValidation.validateNumerical("iamax", "in", in);
    Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, false, dimensions));
+   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0];
}
/**
@@ -911,7 +911,7 @@
public INDArray iamax(INDArray in, boolean keepDims, int... dimensions) {
    NDValidation.validateNumerical("iamax", "in", in);
    Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, keepDims, dimensions));
+   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0];
}
/**
@@ -925,7 +925,7 @@
public INDArray iamin(INDArray in, int... dimensions) {
    NDValidation.validateNumerical("iamin", "in", in);
    Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, false, dimensions));
+   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0];
}
/**
@@ -940,7 +940,7 @@
public INDArray iamin(INDArray in, boolean keepDims, int... dimensions) {
    NDValidation.validateNumerical("iamin", "in", in);
    Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
-   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, keepDims, dimensions));
+   return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0];
}
/**

View File

@@ -17469,6 +17469,60 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
}
// #endif
/**
 * This operation returns the index of the absolute max element in a given NDArray (optionally along given dimension(s))
* Expected input:
* 0: N-dimensional array
* 1: optional axis vector
*
* Int args:
* 0: optional axis
*/
// #if NOT_EXCLUDED(OP_argamax)
@Namespace("sd::ops") public static class argamax extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public argamax(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public argamax(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public argamax position(long position) {
return (argamax)super.position(position);
}
public argamax() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
/**
 * This operation returns the index of the absolute min element in a given NDArray (optionally along given dimension(s))
* Expected input:
* 0: N-dimensional array
* 1: optional axis vector
*
* Int args:
* 0: optional axis
*/
// #if NOT_EXCLUDED(OP_argamin)
@Namespace("sd::ops") public static class argamin extends DeclarableCustomOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public argamin(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public argamin(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public argamin position(long position) {
return (argamin)super.position(position);
}
public argamin() { super((Pointer)null); allocate(); }
private native void allocate();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
}
// #endif
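These generated JavaCPP bindings mirror the C++ declarable ops; from Java they are normally reached through the wrapper classes above rather than instantiated directly. A hedged sketch of how the optional axis described here corresponds to the wrapper's dimensions argument (illustrative input):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax;
import org.nd4j.linalg.factory.Nd4j;

public class ArgAmaxAxis {
    public static void main(String[] args) {
        INDArray x = Nd4j.create(new double[][] {{1, -7, 2}, {0.5, 0.1, -0.2}});
        // the axis becomes the op's int arg: absolute argmax within each row
        INDArray perRow = Nd4j.exec(new ArgAmax(x, 1))[0];
        System.out.println(perRow); // [1, 0]
    }
}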
/**
 * This operation provides various normalization modes:
 * 0: frobenius

View File

@@ -32,8 +32,8 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmin;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
@@ -863,12 +863,12 @@ public class ReductionOpValidation extends BaseOpValidation {
break;
case 2:
reduce = sd.math().iamax(s, dim);
-exp = Nd4j.getExecutioner().exec(new IAMax(in.dup(), dim));
+exp = Nd4j.getExecutioner().exec(new ArgAmax(in.dup(), dim))[0];
name = "iamax";
break;
case 3:
reduce = sd.math().iamin(s, dim);
-exp = Nd4j.getExecutioner().exec(new IAMin(in.dup(), dim));
+exp = Nd4j.getExecutioner().exec(new ArgAmin(in.dup(), dim))[0];
name = "iamin";
break;
case 4:
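The SameDiff path exercised here goes through sd.math().iamax/iamin, which this validation compares against the ArgAmax/ArgAmin reference. A minimal sketch (illustrative input):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class IamaxSameDiff {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable in = sd.var("in", Nd4j.create(new double[] {-3, 1, 2}));
        SDVariable idx = sd.math().iamax(in, 0); // a dimension is required here
        System.out.println(idx.eval());          // 0: |-3| has the largest magnitude
    }
}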

View File

@@ -144,7 +144,7 @@ public class NameScopeTests extends BaseNd4jTest {
scope.close();
-assertTrue("Var with name test/imax exists", SD.variableMap().containsKey("test/imax"));
+assertTrue("Var with name test/argmax exists", SD.variableMap().containsKey("test/argmax"));
}
@Test @Test

View File

@@ -52,10 +52,10 @@ import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmin;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
@@ -3765,10 +3765,10 @@ public class Nd4jTestsC extends BaseNd4jTest {
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL);
INDArray arr = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01});
-IMax iMax = new IMax(arr);
-IAMax iaMax = new IAMax(arr.dup());
-val imax = Nd4j.getExecutioner().execAndReturn(iMax).getFinalResult().intValue();
-val iamax = Nd4j.getExecutioner().execAndReturn(iaMax).getFinalResult().intValue();
+val iMax = new ArgMax(arr);
+val iaMax = new ArgAmax(arr.dup());
+val imax = Nd4j.getExecutioner().exec(iMax)[0].getInt(0);
+val iamax = Nd4j.getExecutioner().exec(iaMax)[0].getInt(0);
// System.out.println("IMAX: " + imax);
// System.out.println("IAMAX: " + iamax);
assertEquals(1, iamax);
@@ -3780,10 +3780,10 @@ public class Nd4jTestsC extends BaseNd4jTest {
public void testIMinIAMin() {
INDArray arr = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01});
INDArray abs = Transforms.abs(arr);
-IAMin iaMin = new IAMin(abs);
-IMin iMin = new IMin(arr.dup());
-double imin = Nd4j.getExecutioner().execAndReturn(iMin).getFinalResult().doubleValue();
-double iamin = Nd4j.getExecutioner().execAndReturn(iaMin).getFinalResult().doubleValue();
+val iaMin = new ArgAmin(abs);
+val iMin = new ArgMin(arr.dup());
+double imin = Nd4j.getExecutioner().exec(iMin)[0].getDouble(0);
+double iamin = Nd4j.getExecutioner().exec(iaMin)[0].getDouble(0);
// System.out.println("IMin: " + imin);
// System.out.println("IAMin: " + iamin);
assertEquals(3, iamin, 1e-12);
@@ -4077,7 +4077,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
arr.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all()).assign(Nd4j.create(slices[i]));
}
-INDArray out = Nd4j.getExecutioner().exec(new IMax(arr, 1,2));
+INDArray out = Nd4j.exec(new ArgMax(arr, 1,2))[0];
assertEquals(DataType.LONG, out.dataType());
@@ -4119,8 +4119,8 @@
}
}
-INDArray actC = Nd4j.getExecutioner().exec(new IMax(arr.dup('c'), 0,1));
-INDArray actF = Nd4j.getExecutioner().exec(new IMax(arr.dup('f'), 0,1));
+INDArray actC = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('c'), 0,1))[0];
+INDArray actF = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('f'), 0,1))[0];
//
assertEquals(exp, actC);
assertEquals(exp, actF);
@@ -4153,8 +4153,8 @@
}
}
-actC = Nd4j.getExecutioner().exec(new IMax(arr.dup('c'), 2, 3));
-actF = Nd4j.getExecutioner().exec(new IMax(arr.dup('f'), 2, 3));
+actC = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('c'), 2, 3))[0];
+actF = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('f'), 2, 3))[0];
assertEquals(exp, actC);
assertEquals(exp, actF);

View File

@@ -25,7 +25,7 @@ import org.junit.runners.Parameterized;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
@@ -122,7 +122,7 @@ public class CrashTest extends BaseNd4jTest {
float sum = x.sumNumber().floatValue();
// index reduction
-Nd4j.getExecutioner().exec(new IMax(x));
+Nd4j.getExecutioner().exec(new ArgMax(x));
// casual transform
Nd4j.getExecutioner().exec(new Sqrt(x, x));

View File

@@ -26,9 +26,9 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
@@ -282,9 +282,9 @@ public class OpExecutionerTests extends BaseNd4jTest {
public void testIamax2() {
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
assertEquals(getFailureMessage(), 3, Nd4j.getBlasWrapper().iamax(linspace));
-val op = new IAMax(linspace);
-int iamax = Nd4j.getExecutioner().execAndReturn(op).getFinalResult().intValue();
+val op = new ArgAmax(linspace);
+int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0);
assertEquals(3, iamax);
}
@@ -565,24 +565,24 @@
@Test
public void testIMax() {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
-IMax imax = new IMax(arr);
-assertEquals(9, Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue());
+ArgMax imax = new ArgMax(arr);
+assertEquals(9, Nd4j.getExecutioner().exec(imax)[0].getInt(0));
arr.muli(-1);
-imax = new IMax(arr);
-int maxIdx = Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue();
+imax = new ArgMax(arr);
+int maxIdx = Nd4j.getExecutioner().exec(imax)[0].getInt(0);
assertEquals(0, maxIdx);
}
@Test
public void testIMin() {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
-IMin imin = new IMin(arr);
-assertEquals(0, Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue());
+ArgMin imin = new ArgMin(arr);
+assertEquals(0, Nd4j.getExecutioner().exec(imin)[0].getInt(0));
arr.muli(-1);
-imin = new IMin(arr);
-int minIdx = Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue();
+imin = new ArgMin(arr);
+int minIdx = Nd4j.getExecutioner().exec(imin)[0].getInt(0);
assertEquals(9, minIdx);
}

View File

@@ -32,8 +32,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
-import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
+import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
@@ -478,24 +478,24 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
@Test
public void testIMax() {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
-IMax imax = new IMax(arr);
-assertEquals(9, Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue());
+ArgMax imax = new ArgMax(arr);
+assertEquals(9, Nd4j.getExecutioner().exec(imax)[0].getInt(0));
arr.muli(-1);
-imax = new IMax(arr);
-int maxIdx = Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue();
+imax = new ArgMax(arr);
+int maxIdx = Nd4j.getExecutioner().exec(imax)[0].getInt(0);
assertEquals(0, maxIdx);
}
@Test
public void testIMin() {
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
-IMin imin = new IMin(arr);
-assertEquals(0, Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue());
+ArgMin imin = new ArgMin(arr);
+assertEquals(0, Nd4j.getExecutioner().exec(imin)[0].getInt(0));
arr.muli(-1);
-imin = new IMin(arr);
-int minIdx = Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue();
+imin = new ArgMin(arr);
+int minIdx = Nd4j.getExecutioner().exec(imin)[0].getInt(0);
assertEquals(9, minIdx);
}

View File

@@ -26,6 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
+import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
@@ -234,7 +235,7 @@ public class EmptyTests extends BaseNd4jTest {
assertEquals(e, reduced);
}
-@Test(expected = IllegalArgumentException.class)
+@Test(expected = ND4JIllegalStateException.class)
public void testEmptyReduction_4() {
val x = Nd4j.create(DataType.FLOAT, 2, 0);
val e = Nd4j.create(DataType.FLOAT, 0);