- new implementations for Index Reductions (#421)
* - new implementations for Index Reductions - small fix in the legacy reduction - disabled index reduction bench tests inside Playground Signed-off-by: Abdelrauf <rauf@konduit.ai> * Allow LIBND4J_TYPES Signed-off-by: Abdelrauf <rauf@konduit.ai> * index reduction stuff split into bunch of units * meh * IMax switched to new impl Signed-off-by: raver119@gmail.com <raver119@gmail.com> * minor fix + test * minor fix * index range fix Signed-off-by: Abdelrauf <rauf@konduit.ai> * noop on empty outputs * minor fix * minor fix Signed-off-by: Abdelrauf <rauf@konduit.ai> * ArgMax replaces IMax Signed-off-by: raver119@gmail.com <raver119@gmail.com> * argmax/argmin/argamax/argamin shape functions updated * ArgAmax/ArgAmin/ArgMin replaces IAMax/IAMin/IMin Signed-off-by: raver119@gmail.com <raver119@gmail.com> * argmax/argmin/argamax/argamin CUDA * IMax replaced in dl4j Signed-off-by: raver119@gmail.com <raver119@gmail.com> * Codegen output * imports fixed Signed-off-by: raver119@gmail.com <raver119@gmail.com> * fix compilation issue Signed-off-by: Abdelrauf <rauf@konduit.ai> * Auto-generate compilation units Signed-off-by: Abdelrauf <rauf@konduit.ai> * Should fix NDArray refactored function calls in indexReductions.cu Signed-off-by: Abdelrauf <rauf@konduit.ai> Co-authored-by: raver119@gmail.com <raver119@gmail.com> Co-authored-by: Alexander Stoyakin <alexander.stoyakin@gmail.com>master
parent
62e9dc83e0
commit
69d91e272a
|
@ -20,7 +20,7 @@ import org.deeplearning4j.clustering.algorithm.Distance;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.ReduceOp;
|
import org.nd4j.linalg.api.ops.ReduceOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.primitives.Pair;
|
import org.nd4j.common.primitives.Pair;
|
||||||
|
|
||||||
|
@ -29,7 +29,7 @@ public class CentersHolder {
|
||||||
private long index = 0;
|
private long index = 0;
|
||||||
|
|
||||||
protected transient ReduceOp op;
|
protected transient ReduceOp op;
|
||||||
protected IMin imin;
|
protected ArgMin imin;
|
||||||
protected transient INDArray distances;
|
protected transient INDArray distances;
|
||||||
protected transient INDArray argMin;
|
protected transient INDArray argMin;
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ public class CentersHolder {
|
||||||
|
|
||||||
if (op == null) {
|
if (op == null) {
|
||||||
op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
|
op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
|
||||||
imin = new IMin(distances, argMin);
|
imin = new ArgMin(distances, argMin);
|
||||||
op.setZ(distances);
|
op.setZ(distances);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,7 +84,7 @@ public class CentersHolder {
|
||||||
|
|
||||||
if (op == null) {
|
if (op == null) {
|
||||||
op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
|
op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
|
||||||
imin = new IMin(distances, argMin);
|
imin = new ArgMin(distances, argMin);
|
||||||
op.setZ(distances);
|
op.setZ(distances);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.junit.Rule;
|
import org.junit.Rule;
|
||||||
import org.junit.rules.TemporaryFolder;
|
import org.junit.rules.TemporaryFolder;
|
||||||
import org.nd4j.common.io.ClassPathResource;
|
import org.nd4j.common.io.ClassPathResource;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||||
import org.deeplearning4j.models.word2vec.VocabWord;
|
import org.deeplearning4j.models.word2vec.VocabWord;
|
||||||
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
import org.deeplearning4j.models.word2vec.wordstore.VocabCache;
|
||||||
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
|
import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator;
|
||||||
|
@ -31,7 +32,6 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac
|
||||||
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
|
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
import org.nd4j.linalg.dataset.DataSet;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.common.util.SerializationUtils;
|
import org.nd4j.common.util.SerializationUtils;
|
||||||
|
@ -111,7 +111,7 @@ public class BagOfWordsVectorizerTest extends BaseDL4JTest {
|
||||||
INDArray labelz = dataSet.getLabels();
|
INDArray labelz = dataSet.getLabels();
|
||||||
log.info("Labels array: " + labelz);
|
log.info("Labels array: " + labelz);
|
||||||
|
|
||||||
int idx2 = Nd4j.getExecutioner().exec(new IMax(labelz)).getInt(0);
|
int idx2 = Nd4j.getExecutioner().exec(new ArgMax(labelz))[0].getInt(0);
|
||||||
//int idx2 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(labelz))).getFinalResult().intValue();
|
//int idx2 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(labelz))).getFinalResult().intValue();
|
||||||
|
|
||||||
// assertEquals(1.0, dataSet.getLabels().getDouble(0), 0.1);
|
// assertEquals(1.0, dataSet.getLabels().getDouble(0), 0.1);
|
||||||
|
@ -125,7 +125,7 @@ public class BagOfWordsVectorizerTest extends BaseDL4JTest {
|
||||||
assertEquals(1, dataSet.getFeatures().getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
|
assertEquals(1, dataSet.getFeatures().getDouble(vocabCache.tokenFor("1").getIndex()), 0.1);
|
||||||
assertEquals(0, dataSet.getFeatures().getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);
|
assertEquals(0, dataSet.getFeatures().getDouble(vocabCache.tokenFor("2").getIndex()), 0.1);
|
||||||
|
|
||||||
int idx1 = Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels())).getInt(0);
|
int idx1 = Nd4j.getExecutioner().exec(new ArgMax(dataSet.getLabels()))[0].getInt(0);
|
||||||
//int idx1 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels()))).getFinalResult().intValue();
|
//int idx1 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels()))).getFinalResult().intValue();
|
||||||
|
|
||||||
//assertEquals(0.0, dataSet.getLabels().getDouble(0), 0.1);
|
//assertEquals(0.0, dataSet.getLabels().getDouble(0), 0.1);
|
||||||
|
|
|
@ -294,12 +294,26 @@ elseif(SD_CPU)
|
||||||
file(GLOB_RECURSE LEGACY_SOURCES false ../include/legacy/impl/*.cpp ../include/legacy/cpu/*.cpp ../include/legacy/*.h)
|
file(GLOB_RECURSE LEGACY_SOURCES false ../include/legacy/impl/*.cpp ../include/legacy/cpu/*.cpp ../include/legacy/*.h)
|
||||||
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
|
file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h)
|
||||||
|
|
||||||
|
|
||||||
|
file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in)
|
||||||
|
foreach(FL_ITEM ${COMPILATION_UNITS})
|
||||||
|
string(REGEX MATCH "^(.*)\\.cpp\.in$" dummy ${FL_ITEM})
|
||||||
|
set(FL_ITEM_WLE ${CMAKE_MATCH_1})
|
||||||
|
foreach(FL_TYPE_INDEX RANGE 0 9)
|
||||||
|
message( "${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp")
|
||||||
|
configure_file( "${FL_ITEM}" "${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp" @ONLY)
|
||||||
|
LIST(APPEND CUSTOMOPS_GENERIC_SOURCES ${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp )
|
||||||
|
endforeach()
|
||||||
|
endforeach()
|
||||||
|
|
||||||
if (SD_X86_BUILD)
|
if (SD_X86_BUILD)
|
||||||
# we disable platform optimizations for certains files for linux/macos
|
# we disable platform optimizations for certains files for linux/macos
|
||||||
set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
||||||
set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if(SD_CHECK_VECTORIZATION)
|
if(SD_CHECK_VECTORIZATION)
|
||||||
set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
|
set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES})
|
||||||
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU")
|
||||||
|
|
|
@ -19,12 +19,13 @@
|
||||||
//
|
//
|
||||||
#ifndef LIBND4J_LOOPCOORDSHELPER_H
|
#ifndef LIBND4J_LOOPCOORDSHELPER_H
|
||||||
#define LIBND4J_LOOPCOORDSHELPER_H
|
#define LIBND4J_LOOPCOORDSHELPER_H
|
||||||
|
#include <vector>
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <type_traits>
|
#include <type_traits>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include <system/pointercast.h>
|
#include <system/pointercast.h>
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <helpers/shape.h>
|
||||||
namespace sd {
|
namespace sd {
|
||||||
|
|
||||||
#if defined(__GNUC__)
|
#if defined(__GNUC__)
|
||||||
|
@ -125,7 +126,7 @@ namespace sd {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
FORCEINLINE zip_size_t offset_from_coords(const Nd4jLong*& x_strides, const Nd4jLong*& z_strides, const Nd4jLong* coords, const Nd4jLong& rank) {
|
FORCEINLINE zip_size_t offset_from_coords(const Nd4jLong* x_strides, const Nd4jLong* z_strides, const Nd4jLong* coords, const Nd4jLong& rank) {
|
||||||
|
|
||||||
zip_size_t offset = { 0,0 };
|
zip_size_t offset = { 0,0 };
|
||||||
size_t rank_4 = rank & -4;
|
size_t rank_4 = rank & -4;
|
||||||
|
@ -435,6 +436,509 @@ namespace sd {
|
||||||
return last_offset;
|
return last_offset;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
struct triple_size_t {
|
||||||
|
size_t first;
|
||||||
|
size_t second;
|
||||||
|
size_t third;
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
template<bool Last_Index_Faster = true>
|
||||||
|
FORCEINLINE triple_size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* y_strides, const Nd4jLong* z_strides, Nd4jLong* coords, triple_size_t last_offset, const size_t rank, const size_t skip = 0) {
|
||||||
|
|
||||||
|
Nd4jLong val = 0;
|
||||||
|
for (int i = rank - skip - 1; i >= 0; i--) {
|
||||||
|
val = coords[i] + 1;
|
||||||
|
if (likely(val < bases[i])) {
|
||||||
|
coords[i] = val;
|
||||||
|
last_offset.first += x_strides[i];
|
||||||
|
last_offset.second += y_strides[i];
|
||||||
|
last_offset.third += z_strides[i];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
last_offset.first -= coords[i] * x_strides[i];
|
||||||
|
last_offset.second -= coords[i] * y_strides[i];
|
||||||
|
last_offset.third -= coords[i] * z_strides[i];
|
||||||
|
coords[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return last_offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<>
|
||||||
|
FORCEINLINE triple_size_t inc_coords<false>(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* y_strides, const Nd4jLong* z_strides, Nd4jLong* coords, triple_size_t last_offset, const size_t rank, const size_t skip) {
|
||||||
|
|
||||||
|
Nd4jLong val = 0;
|
||||||
|
for (int i = skip; i < rank; i++) {
|
||||||
|
val = coords[i] + 1;
|
||||||
|
if (likely(val < bases[i])) {
|
||||||
|
coords[i] = val;
|
||||||
|
|
||||||
|
last_offset.first += x_strides[i];
|
||||||
|
last_offset.second += y_strides[i];
|
||||||
|
last_offset.third += z_strides[i];
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
last_offset.first -= coords[i] * x_strides[i];
|
||||||
|
last_offset.second -= coords[i] * y_strides[i];
|
||||||
|
last_offset.third -= coords[i] * z_strides[i];
|
||||||
|
coords[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return last_offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
FORCEINLINE triple_size_t offset_from_coords(const Nd4jLong* x_strides, const Nd4jLong* y_strides, const Nd4jLong* z_strides, const Nd4jLong* coords, const Nd4jLong& rank) {
|
||||||
|
|
||||||
|
triple_size_t offset = { 0,0 ,0 };
|
||||||
|
size_t rank_4 = rank & -4;
|
||||||
|
for (int i = 0; i < rank_4; i += 4) {
|
||||||
|
offset.first = offset.first
|
||||||
|
+ coords[i] * x_strides[i]
|
||||||
|
+ coords[i + 1] * x_strides[i + 1]
|
||||||
|
+ coords[i + 2] * x_strides[i + 2]
|
||||||
|
+ coords[i + 3] * x_strides[i + 3];
|
||||||
|
offset.second = offset.second
|
||||||
|
+ coords[i] * y_strides[i]
|
||||||
|
+ coords[i + 1] * y_strides[i + 1]
|
||||||
|
+ coords[i + 2] * y_strides[i + 2]
|
||||||
|
+ coords[i + 3] * y_strides[i + 3];
|
||||||
|
offset.third = offset.third
|
||||||
|
+ coords[i] * z_strides[i]
|
||||||
|
+ coords[i + 1] * z_strides[i + 1]
|
||||||
|
+ coords[i + 2] * z_strides[i + 2]
|
||||||
|
+ coords[i + 3] * z_strides[i + 3];
|
||||||
|
}
|
||||||
|
for (int i = rank_4; i < rank; i++) {
|
||||||
|
offset.first += coords[i] * x_strides[i];
|
||||||
|
offset.second += coords[i] * y_strides[i];
|
||||||
|
offset.third += coords[i] * z_strides[i];
|
||||||
|
}
|
||||||
|
return offset;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<bool Last_Index_Faster = true>
|
||||||
|
FORCEINLINE Nd4jLong getLength(const Nd4jLong* bases, int rank, int skip = 0)
|
||||||
|
{
|
||||||
|
if (skip < 0 || skip >= rank) skip = 0;
|
||||||
|
Nd4jLong total = 1;
|
||||||
|
for (int i = 0; i < rank - skip; i++) {
|
||||||
|
total *= bases[i];
|
||||||
|
}
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<>
|
||||||
|
FORCEINLINE Nd4jLong getLength<false>(const Nd4jLong* bases, int rank, int skip)
|
||||||
|
{
|
||||||
|
if (skip < 0 || skip >= rank) skip = 0;
|
||||||
|
Nd4jLong total = 1;
|
||||||
|
for (int i = skip; i < rank; i++) {
|
||||||
|
total *= bases[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<bool Last_Index_Faster = true>
|
||||||
|
FORCEINLINE Nd4jLong getLength(const Nd4jLong* bases, int rank, int skip, Nd4jLong& outSkippedLength)
|
||||||
|
{
|
||||||
|
if (skip < 0 || skip >= rank) skip = 0;
|
||||||
|
Nd4jLong total = 1;
|
||||||
|
for (int i = 0; i < rank - skip; i++) {
|
||||||
|
total *= bases[i];
|
||||||
|
}
|
||||||
|
if (skip > 0) {
|
||||||
|
outSkippedLength = 1;
|
||||||
|
for (int i = rank - skip; i < rank; i++) {
|
||||||
|
outSkippedLength *= bases[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
outSkippedLength = 0;
|
||||||
|
}
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<>
|
||||||
|
FORCEINLINE Nd4jLong getLength<false>(const Nd4jLong* bases, int rank, int skip, Nd4jLong& outSkippedLength)
|
||||||
|
{
|
||||||
|
if (skip < 0 || skip >= rank) skip = 0;
|
||||||
|
if (skip > 0) {
|
||||||
|
outSkippedLength = 1;
|
||||||
|
for (int i = 0; i < skip; i++) {
|
||||||
|
outSkippedLength *= bases[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
outSkippedLength = 0;
|
||||||
|
}
|
||||||
|
Nd4jLong total = 1;
|
||||||
|
for (int i = skip; i < rank; i++) {
|
||||||
|
total *= bases[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
return total;
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
for ODR rule it willbe declared as inline
|
||||||
|
rePartition for reductions and et cet
|
||||||
|
Indices mentioned in the dimension list will be moved to the tail
|
||||||
|
This way it will be splitted into two parts
|
||||||
|
the first part will contain output part,the second tail part will be used for reductions and other purposes
|
||||||
|
if squash is True then it will attempt to minimize the output ( for both orders) and the tail
|
||||||
|
*/
|
||||||
|
|
||||||
|
FORCEINLINE void rePartition(char order, const std::vector<int>& dimensions, const size_t rank, const Nd4jLong* bases, const Nd4jLong* strides, Nd4jLong(&new_bases)[MAX_RANK], Nd4jLong(&new_strides)[MAX_RANK], int& first_begin, int& first_end, int& second_begin, int& second_end, bool first_squash = false, bool second_squash = true) {
|
||||||
|
|
||||||
|
bool indices[MAX_RANK] = {};
|
||||||
|
int ind = 0;
|
||||||
|
size_t second_rank;
|
||||||
|
if (dimensions.size() == 0 || (dimensions.size() == 1 && dimensions.at(0) == sd::DataTypeUtils::max<int>())){
|
||||||
|
first_end = 0;
|
||||||
|
first_begin = 0;
|
||||||
|
//treat it as the whole
|
||||||
|
for (int i = 0; i < rank; i++) {
|
||||||
|
new_bases[i] = bases[i];
|
||||||
|
new_strides[i] = strides[i];
|
||||||
|
}
|
||||||
|
second_rank = rank;
|
||||||
|
second_end = rank;
|
||||||
|
second_begin = 0;
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (int index : dimensions) {
|
||||||
|
if (index < 0) index = rank + index;
|
||||||
|
if (index >= 0 && index < rank) {
|
||||||
|
indices[index] = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
//move output ones and
|
||||||
|
for (int i = 0; i < rank; i++) {
|
||||||
|
|
||||||
|
if (!indices[i]) {
|
||||||
|
|
||||||
|
new_bases[ind] = bases[i];
|
||||||
|
new_strides[ind] = strides[i];
|
||||||
|
ind++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
int first_rank = ind;
|
||||||
|
|
||||||
|
first_end = ind;
|
||||||
|
first_begin = 0;
|
||||||
|
//nd4j_printf("rffrr ss & %d ind-- %d %d\n", first_rank, first_begin, first_end);
|
||||||
|
//squash output rank
|
||||||
|
if (first_squash && first_rank > 1) {
|
||||||
|
|
||||||
|
if (order == 'c') {
|
||||||
|
int uniq_ind = first_end-1;
|
||||||
|
for (int i = first_end - 2; i >= first_begin; i--) {
|
||||||
|
if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) {
|
||||||
|
new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind];
|
||||||
|
new_strides[uniq_ind] = new_strides[uniq_ind];
|
||||||
|
--first_rank;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
--uniq_ind;
|
||||||
|
new_bases[uniq_ind] = new_bases[i];
|
||||||
|
new_strides[uniq_ind] = new_strides[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
first_begin = first_end - first_rank;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
//squash fortran
|
||||||
|
int uniq_ind = 0;
|
||||||
|
for (int i = 1; i < first_end; i++) {
|
||||||
|
if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) {
|
||||||
|
new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind];
|
||||||
|
new_strides[uniq_ind] = new_strides[uniq_ind];
|
||||||
|
--first_rank;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
uniq_ind++;
|
||||||
|
new_bases[uniq_ind] = new_bases[i];
|
||||||
|
new_strides[uniq_ind] = new_strides[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
first_end = first_begin + first_rank;
|
||||||
|
|
||||||
|
}
|
||||||
|
ind = first_end;
|
||||||
|
}
|
||||||
|
|
||||||
|
//nd4j_printf("rffrr ss & %d ind-- %d %d\n", first_rank, first_begin, first_end);
|
||||||
|
//move process indices
|
||||||
|
for (int i = 0; i < rank; i++) {
|
||||||
|
if (indices[i]) {
|
||||||
|
new_bases[ind] = bases[i];
|
||||||
|
new_strides[ind] = strides[i];
|
||||||
|
ind++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
second_rank = ind - first_end;
|
||||||
|
second_end = ind;
|
||||||
|
second_begin = first_end;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
if (second_squash && second_rank > 1) {
|
||||||
|
|
||||||
|
if (order == 'c') {
|
||||||
|
int uniq_ind = second_end - 1;
|
||||||
|
for (int i = second_end - 2; i >= second_begin; i--) {
|
||||||
|
if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) {
|
||||||
|
new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind];
|
||||||
|
new_strides[uniq_ind] = new_strides[uniq_ind];
|
||||||
|
--second_rank;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
--uniq_ind;
|
||||||
|
new_bases[uniq_ind] = new_bases[i];
|
||||||
|
new_strides[uniq_ind] = new_strides[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
second_begin = second_end - second_rank;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
int uniq_ind = second_begin;
|
||||||
|
for (int i = second_begin+1; i < second_end; i++) {
|
||||||
|
if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) {
|
||||||
|
new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind];
|
||||||
|
new_strides[uniq_ind] = new_strides[uniq_ind];
|
||||||
|
--second_rank;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
uniq_ind++;
|
||||||
|
new_bases[uniq_ind] = new_bases[i];
|
||||||
|
new_strides[uniq_ind] = new_strides[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
second_end = second_begin + second_rank;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
//basic CRTP static polymorphism classes for offset increments
|
||||||
|
|
||||||
|
template<typename Derived>
|
||||||
|
struct CoordsBaseMovement {
|
||||||
|
void init(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
|
||||||
|
static_cast<Derived*>(this)->initImpl(bases, strides1, strides2, rank, start);
|
||||||
|
}
|
||||||
|
|
||||||
|
void increment(int skipRank = 0) {
|
||||||
|
static_cast<Derived*>(this)->incrementImpl(skipRank);
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong First() { return static_cast<Derived*>(this)->FirstImpl(); };
|
||||||
|
Nd4jLong Second() { return static_cast<Derived*>(this)->SecondImpl(); };
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
struct ZipGenericCoordsRank1Stride1 : CoordsBaseMovement<ZipGenericCoordsRank1Stride1> {
|
||||||
|
|
||||||
|
size_t offset1;
|
||||||
|
size_t offset2;
|
||||||
|
|
||||||
|
|
||||||
|
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
|
||||||
|
offset1 = start;
|
||||||
|
offset2 = start;
|
||||||
|
}
|
||||||
|
|
||||||
|
void incrementImpl(int skipRank = 0) {
|
||||||
|
offset1 += 1;
|
||||||
|
offset2 += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong FirstImpl() { return offset1; };
|
||||||
|
Nd4jLong SecondImpl() { return offset2; };
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
struct ZipGenericCoordsRank1BothStrideN : CoordsBaseMovement<ZipGenericCoordsRank1BothStrideN> {
|
||||||
|
size_t stride1;
|
||||||
|
size_t stride2;
|
||||||
|
size_t offset1;
|
||||||
|
size_t offset2;
|
||||||
|
|
||||||
|
|
||||||
|
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
|
||||||
|
stride1 = strides1[0];
|
||||||
|
stride2 = strides2[0];
|
||||||
|
offset1 = start * stride1;
|
||||||
|
offset2 = start * stride2;
|
||||||
|
}
|
||||||
|
|
||||||
|
void incrementImpl(int skipRank = 0) {
|
||||||
|
offset1 += stride1;
|
||||||
|
offset2 += stride2;
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong FirstImpl() { return offset1; };
|
||||||
|
Nd4jLong SecondImpl() { return offset2; };
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int ConstRank, bool LastIndexFaster = true>
|
||||||
|
struct ZipGenericCoordsConstMovementSecondStride1 : CoordsBaseMovement<ZipGenericCoordsConstMovementSecondStride1<ConstRank, LastIndexFaster>> {
|
||||||
|
sd::CoordsState<ConstRank - 1> cst;
|
||||||
|
Nd4jLong coords[MAX_RANK];
|
||||||
|
size_t offset1;
|
||||||
|
size_t offset2;
|
||||||
|
int _rank;
|
||||||
|
|
||||||
|
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
|
||||||
|
offset1 = sd::init_coords<ConstRank, 0, LastIndexFaster>(cst, start, bases, strides1);
|
||||||
|
offset2 = start * 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
void incrementImpl(int skipRank = 0) {
|
||||||
|
offset1 = sd::inc_coords<ConstRank, 0, LastIndexFaster>(cst, offset1);
|
||||||
|
offset2 += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong FirstImpl() { return offset1; };
|
||||||
|
Nd4jLong SecondImpl() { return offset2; };
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
template<int ConstRank, bool LastIndexFaster = true>
|
||||||
|
struct ZipGenericCoordsConstMovementSecondStrideN : CoordsBaseMovement<ZipGenericCoordsConstMovementSecondStrideN<ConstRank, LastIndexFaster>> {
|
||||||
|
sd::CoordsState<ConstRank - 1> cst;
|
||||||
|
Nd4jLong _stride2;
|
||||||
|
Nd4jLong coords[MAX_RANK];
|
||||||
|
size_t offset1;
|
||||||
|
size_t offset2;
|
||||||
|
int _rank;
|
||||||
|
|
||||||
|
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
|
||||||
|
_stride2 = strides2[0];
|
||||||
|
offset1 = sd::init_coords<ConstRank, 0, LastIndexFaster>(cst, start, bases, strides1);
|
||||||
|
offset2 = start * _stride2;
|
||||||
|
}
|
||||||
|
|
||||||
|
void incrementImpl(int skipRank = 0) {
|
||||||
|
offset1 = sd::inc_coords<ConstRank, 0, LastIndexFaster>(cst, offset1);
|
||||||
|
offset2 += _stride2;
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong FirstImpl() { return offset1; };
|
||||||
|
Nd4jLong SecondImpl() { return offset2; };
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
template<bool LastIndexFaster = true>
|
||||||
|
struct ZipGenericCoordsMovementSecondStrideN : CoordsBaseMovement<ZipGenericCoordsMovementSecondStrideN<LastIndexFaster>> {
|
||||||
|
const Nd4jLong* _bases;
|
||||||
|
const Nd4jLong* _strides1;
|
||||||
|
Nd4jLong _stride2;
|
||||||
|
Nd4jLong coords[MAX_RANK];
|
||||||
|
zip_size_t offset;
|
||||||
|
int _rank;
|
||||||
|
|
||||||
|
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
|
||||||
|
|
||||||
|
_bases = bases;
|
||||||
|
_strides1 = strides1;
|
||||||
|
_stride2 = strides2[0];
|
||||||
|
_rank = rank;
|
||||||
|
if (start == 0) {
|
||||||
|
for (int i = 0; i < MAX_RANK; i++) {
|
||||||
|
coords[i] = 0;
|
||||||
|
}
|
||||||
|
offset = { 0,0 };
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (LastIndexFaster) {
|
||||||
|
sd::index2coords_C(start, rank, bases, (Nd4jLong*)&coords);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
sd::index2coords_F(start, rank, bases, (Nd4jLong*)&coords);
|
||||||
|
}
|
||||||
|
offset.first = sd::offset_from_coords(strides1, (Nd4jLong*)&coords, rank);
|
||||||
|
offset.second = start * _stride2;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void incrementImpl(int skipRank = 0) {
|
||||||
|
offset.first = inc_coords<LastIndexFaster>(_bases, _strides1, (Nd4jLong*)&coords, offset.first, _rank, skipRank);
|
||||||
|
offset.second += _stride2;
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong FirstImpl() { return offset.first; };
|
||||||
|
Nd4jLong SecondImpl() { return offset.second; };
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
|
template<bool LastIndexFaster = true>
|
||||||
|
struct ZipGenericCoordsMovement : CoordsBaseMovement<ZipGenericCoordsMovement<LastIndexFaster>> {
|
||||||
|
const Nd4jLong* _bases;
|
||||||
|
const Nd4jLong* _strides1;
|
||||||
|
const Nd4jLong* _strides2;
|
||||||
|
Nd4jLong coords[MAX_RANK];
|
||||||
|
zip_size_t offset;
|
||||||
|
int _rank;
|
||||||
|
|
||||||
|
void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) {
|
||||||
|
|
||||||
|
_bases = bases;
|
||||||
|
_strides1 = strides1;
|
||||||
|
_strides2 = strides2;
|
||||||
|
_rank = rank;
|
||||||
|
if (start == 0) {
|
||||||
|
for (int i = 0; i < MAX_RANK; i++) {
|
||||||
|
coords[i] = 0;
|
||||||
|
}
|
||||||
|
offset = { 0,0 };
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (LastIndexFaster) {
|
||||||
|
sd::index2coords_C(start, rank, bases, (Nd4jLong*)&coords);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
sd::index2coords_F(start, rank, bases, (Nd4jLong*)&coords);
|
||||||
|
}
|
||||||
|
offset = sd::offset_from_coords(strides1, strides2, (Nd4jLong*)&coords, rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
void incrementImpl(int skipRank = 0) {
|
||||||
|
offset = inc_coords<LastIndexFaster>(_bases, _strides1, _strides2, (Nd4jLong*)&coords, offset, _rank, skipRank);
|
||||||
|
}
|
||||||
|
|
||||||
|
Nd4jLong FirstImpl() { return offset.first; };
|
||||||
|
Nd4jLong SecondImpl() { return offset.second; };
|
||||||
|
|
||||||
|
};
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
#endif
|
#endif
|
|
@ -69,7 +69,7 @@ Nd4jLong IndexReduce<X, Y>::execScalar(const void *vx, const Nd4jLong *xShapeInf
|
||||||
for (int e = 0; e < maxThreads; e++)
|
for (int e = 0; e < maxThreads; e++)
|
||||||
intermediatery[e].index = -1;
|
intermediatery[e].index = -1;
|
||||||
|
|
||||||
if (xEws == 1) {
|
if (xEws == 1 && shape::order(xShapeInfo) == 'c') {
|
||||||
auto func = PRAGMA_THREADS_FOR {
|
auto func = PRAGMA_THREADS_FOR {
|
||||||
intermediatery[thread_id] = OpType::startingIndexValue(x);
|
intermediatery[thread_id] = OpType::startingIndexValue(x);
|
||||||
|
|
||||||
|
|
|
@ -188,7 +188,7 @@ namespace functions {
|
||||||
auto reductionBuffer = static_cast<X*>(vreductionBuffer);
|
auto reductionBuffer = static_cast<X*>(vreductionBuffer);
|
||||||
auto order = shape::order(xShapeInfo);
|
auto order = shape::order(xShapeInfo);
|
||||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
__shared__ volatile int resultScalar;
|
__shared__ volatile bool resultScalar;
|
||||||
|
|
||||||
//shared memory space for storing intermediate results
|
//shared memory space for storing intermediate results
|
||||||
__shared__ IndexValue<X>* sPartials;
|
__shared__ IndexValue<X>* sPartials;
|
||||||
|
@ -214,17 +214,10 @@ namespace functions {
|
||||||
zLen = shape::length(zShapeInfo);
|
zLen = shape::length(zShapeInfo);
|
||||||
else zLen = 1;
|
else zLen = 1;
|
||||||
|
|
||||||
if (dimensionLength == 1) {
|
|
||||||
if (zLen == 1 && (dimension == nullptr || dimension[0] == MAX_DIMENSION))
|
|
||||||
resultScalar = 1;
|
|
||||||
else
|
|
||||||
resultScalar = 0;
|
|
||||||
}
|
|
||||||
else
|
|
||||||
resultScalar = 0;
|
|
||||||
|
|
||||||
if (zLen == 1)
|
if (zLen == 1)
|
||||||
resultScalar = 1;
|
resultScalar = true;
|
||||||
|
else
|
||||||
|
resultScalar = false;
|
||||||
|
|
||||||
xLength = shape::length(xShapeInfo);
|
xLength = shape::length(xShapeInfo);
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,95 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
// Created by Abdelrauf 2020 (based on argmax)
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_argamax)
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/axis.h>
|
||||||
|
#include <ops/declarable/helpers/reductions.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
DECLARE_TYPES(argamax) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS })
|
||||||
|
->setAllowedOutputTypes({ ALL_INTS });
|
||||||
|
}
|
||||||
|
|
||||||
|
CUSTOM_OP_IMPL(argamax, 1, 1, false, 0, -2) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (output->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
auto axis = *block.getIArguments();
|
||||||
|
|
||||||
|
// axis might be dynamic (i.e. tf mode)
|
||||||
|
if (block.width() > 1 && axis.size() == 0) {
|
||||||
|
auto axisVector = INPUT_VARIABLE(1);
|
||||||
|
helpers::adjustAxis(input->rankOf(), axisVector, axis);
|
||||||
|
helpers::argAbsMax(*input, *output, axis);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
helpers::argAbsMax(*input, *output, axis);
|
||||||
|
}
|
||||||
|
|
||||||
|
STORE_RESULT(output);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(argamax) {
|
||||||
|
std::vector<int> dims;
|
||||||
|
|
||||||
|
if (block.width() == 1) {
|
||||||
|
dims = *block.getIArguments();
|
||||||
|
} else {
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
dims = y->template asVectorT<int>();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto keepDims = block.numB() ? B_ARG(0) : false;
|
||||||
|
auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;
|
||||||
|
|
||||||
|
// we're resolving negative axis here
|
||||||
|
helpers::adjustAxis(shape::rank(inputShape->at(0)), dims);
|
||||||
|
|
||||||
|
auto in = inputShape->at(0);
|
||||||
|
for (auto d : dims) {
|
||||||
|
// we have special case here
|
||||||
|
if (d == sd::DataTypeUtils::max<int>())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(d < shape::rank(in), 0, "ArgAmax: axis can't be above rank")
|
||||||
|
REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgAmax: you can't reduce along axis with 0 in shape");
|
||||||
|
}
|
||||||
|
|
||||||
|
// special case - output is scalar
|
||||||
|
if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype));
|
||||||
|
}
|
||||||
|
|
||||||
|
return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -0,0 +1,95 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
// Created by Abdelrauf 2020 (based on argmax)
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_argamin)
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/axis.h>
|
||||||
|
#include <ops/declarable/helpers/reductions.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
DECLARE_TYPES(argamin) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS })
|
||||||
|
->setAllowedOutputTypes({ ALL_INTS });
|
||||||
|
}
|
||||||
|
|
||||||
|
CUSTOM_OP_IMPL(argamin, 1, 1, false, 0, -2) {
|
||||||
|
auto input = INPUT_VARIABLE(0);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (output->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
|
auto axis = *block.getIArguments();
|
||||||
|
|
||||||
|
// axis might be dynamic (i.e. tf mode)
|
||||||
|
if (block.width() > 1 && axis.size() == 0) {
|
||||||
|
auto axisVector = INPUT_VARIABLE(1);
|
||||||
|
helpers::adjustAxis(input->rankOf(), axisVector, axis);
|
||||||
|
helpers::argAbsMin(*input, *output, axis);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
helpers::argAbsMin(*input, *output, axis);
|
||||||
|
}
|
||||||
|
|
||||||
|
STORE_RESULT(output);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(argamin) {
|
||||||
|
std::vector<int> dims;
|
||||||
|
|
||||||
|
if (block.width() == 1) {
|
||||||
|
dims = *block.getIArguments();
|
||||||
|
} else {
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
dims = y->template asVectorT<int>();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto keepDims = block.numB() ? B_ARG(0) : false;
|
||||||
|
auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;
|
||||||
|
|
||||||
|
// we're resolving negative axis here
|
||||||
|
helpers::adjustAxis(shape::rank(inputShape->at(0)), dims);
|
||||||
|
|
||||||
|
auto in = inputShape->at(0);
|
||||||
|
for (auto d : dims) {
|
||||||
|
// we have special case here
|
||||||
|
if (d == sd::DataTypeUtils::max<int>())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(d < shape::rank(in), 0, "ArgAmin: axis can't be above rank")
|
||||||
|
REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgAmin: you can't reduce along axis with 0 in shape");
|
||||||
|
}
|
||||||
|
|
||||||
|
// special case - output is scalar
|
||||||
|
if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype));
|
||||||
|
}
|
||||||
|
|
||||||
|
return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -1,6 +1,6 @@
|
||||||
/*******************************************************************************
|
/*******************************************************************************
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
*
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
* This program and the accompanying materials are made available under the
|
* This program and the accompanying materials are made available under the
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
@ -22,6 +22,7 @@
|
||||||
#if NOT_EXCLUDED(OP_argmax)
|
#if NOT_EXCLUDED(OP_argmax)
|
||||||
|
|
||||||
#include <ops/declarable/helpers/axis.h>
|
#include <ops/declarable/helpers/axis.h>
|
||||||
|
#include <ops/declarable/helpers/reductions.h>
|
||||||
#include <ops/declarable/CustomOperations.h>
|
#include <ops/declarable/CustomOperations.h>
|
||||||
#include <helpers/ConstantTadHelper.h>
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
|
@ -29,7 +30,7 @@ namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
DECLARE_TYPES(argmax) {
|
DECLARE_TYPES(argmax) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(sd::DataType::ANY)
|
->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS })
|
||||||
->setAllowedOutputTypes({ALL_INTS});
|
->setAllowedOutputTypes({ALL_INTS});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -37,18 +38,19 @@ namespace sd {
|
||||||
auto input = INPUT_VARIABLE(0);
|
auto input = INPUT_VARIABLE(0);
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (output->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
auto axis = *block.getIArguments();
|
auto axis = *block.getIArguments();
|
||||||
|
|
||||||
// axis might be dynamic (i.e. tf mode)
|
// axis might be dynamic (i.e. tf mode)
|
||||||
if (block.width() > 1 && axis.size() == 0) {
|
if (block.width() > 1 && axis.size() == 0) {
|
||||||
auto axisVector = INPUT_VARIABLE(1);
|
auto axisVector = INPUT_VARIABLE(1);
|
||||||
helpers::adjustAxis(input->rankOf(), axisVector, axis);
|
helpers::adjustAxis(input->rankOf(), axisVector, axis);
|
||||||
|
helpers::argMax(*input, *output, axis);
|
||||||
input->applyIndexReduce(indexreduce::IndexMax, *output, axis);
|
|
||||||
} else {
|
} else {
|
||||||
helpers::adjustAxis(input->rankOf(), axis);
|
helpers::argMax(*input, *output, axis);
|
||||||
|
|
||||||
input->applyIndexReduce(indexreduce::IndexMax, *output, axis);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
STORE_RESULT(output);
|
STORE_RESULT(output);
|
||||||
|
@ -66,23 +68,28 @@ namespace sd {
|
||||||
dims = y->template asVectorT<int>();
|
dims = y->template asVectorT<int>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto keepDims = block.numB() ? B_ARG(0) : false;
|
||||||
|
auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;
|
||||||
|
|
||||||
// we're resolving negative axis here
|
// we're resolving negative axis here
|
||||||
helpers::adjustAxis(shape::rank(inputShape->at(0)), dims);
|
helpers::adjustAxis(shape::rank(inputShape->at(0)), dims);
|
||||||
|
|
||||||
if (dims.size() > 1)
|
auto in = inputShape->at(0);
|
||||||
std::sort(dims.begin(), dims.end());
|
for (auto d : dims) {
|
||||||
|
// we have special case here
|
||||||
|
if (d == sd::DataTypeUtils::max<int>())
|
||||||
|
continue;
|
||||||
|
|
||||||
|
REQUIRE_TRUE(d < shape::rank(in), 0, "ArgMax: axis can't be above rank")
|
||||||
for (auto d:dims) {
|
REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape");
|
||||||
REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// special case - output is scalar
|
// special case - output is scalar
|
||||||
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
|
if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype));
|
||||||
}
|
}
|
||||||
|
|
||||||
return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), DataType::INT64, false, false, block.getWorkspace()));
|
return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -21,15 +21,17 @@
|
||||||
#include <system/op_boilerplate.h>
|
#include <system/op_boilerplate.h>
|
||||||
#if NOT_EXCLUDED(OP_argmin)
|
#if NOT_EXCLUDED(OP_argmin)
|
||||||
|
|
||||||
#include <ops/declarable/CustomOperations.h>
|
|
||||||
#include <ops/declarable/helpers/axis.h>
|
#include <ops/declarable/helpers/axis.h>
|
||||||
|
#include <ops/declarable/helpers/reductions.h>
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
||||||
|
|
||||||
DECLARE_TYPES(argmin) {
|
DECLARE_TYPES(argmin) {
|
||||||
getOpDescriptor()
|
getOpDescriptor()
|
||||||
->setAllowedInputTypes(sd::DataType::ANY)
|
->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS })
|
||||||
->setAllowedOutputTypes({ALL_INTS});
|
->setAllowedOutputTypes({ALL_INTS});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -39,16 +41,18 @@ namespace sd {
|
||||||
|
|
||||||
auto output = OUTPUT_VARIABLE(0);
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
if (output->isEmpty())
|
||||||
|
return Status::OK();
|
||||||
|
|
||||||
// axis might be dynamic (i.e. tf mode)
|
// axis might be dynamic (i.e. tf mode)
|
||||||
if (block.width() > 1 && axis.size() == 0) {
|
if (block.width() > 1 && axis.size() == 0) {
|
||||||
auto axisVector = INPUT_VARIABLE(1);
|
auto axisVector = INPUT_VARIABLE(1);
|
||||||
helpers::adjustAxis(input->rankOf(), axisVector, axis);
|
helpers::adjustAxis(input->rankOf(), axisVector, axis);
|
||||||
|
helpers::argMin(*input, *output, axis);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
helpers::argMin(*input, *output, axis);
|
||||||
|
|
||||||
input->applyIndexReduce(indexreduce::IndexMin, *output, axis);
|
|
||||||
} else {
|
|
||||||
helpers::adjustAxis(input->rankOf(), axis);
|
|
||||||
|
|
||||||
input->applyIndexReduce(indexreduce::IndexMin, *output, axis);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
STORE_RESULT(output);
|
STORE_RESULT(output);
|
||||||
|
@ -58,7 +62,7 @@ namespace sd {
|
||||||
|
|
||||||
DECLARE_SHAPE_FN(argmin) {
|
DECLARE_SHAPE_FN(argmin) {
|
||||||
std::vector<int> dims;
|
std::vector<int> dims;
|
||||||
auto in = inputShape->at(0);
|
|
||||||
if (block.width() == 1) {
|
if (block.width() == 1) {
|
||||||
dims = *block.getIArguments();
|
dims = *block.getIArguments();
|
||||||
} else {
|
} else {
|
||||||
|
@ -66,23 +70,28 @@ namespace sd {
|
||||||
dims = y->template asVectorT<int>();
|
dims = y->template asVectorT<int>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto keepDims = block.numB() ? B_ARG(0) : false;
|
||||||
|
auto dtype = block.numD() ? D_ARG(0) : DataType::INT64;
|
||||||
|
|
||||||
// we're resolving negative axis here
|
// we're resolving negative axis here
|
||||||
helpers::adjustAxis(shape::rank(in), dims);
|
helpers::adjustAxis(shape::rank(inputShape->at(0)), dims);
|
||||||
|
|
||||||
if (dims.size() > 1)
|
auto in = inputShape->at(0);
|
||||||
std::sort(dims.begin(), dims.end());
|
for (auto d : dims) {
|
||||||
|
// we have special case here
|
||||||
|
if (d == sd::DataTypeUtils::max<int>())
|
||||||
|
continue;
|
||||||
|
|
||||||
for (auto d:dims) {
|
REQUIRE_TRUE(d < shape::rank(in), 0, "ArgMin: axis can't be above rank")
|
||||||
REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape");
|
REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape");
|
||||||
}
|
}
|
||||||
|
|
||||||
// special case - output is scalar
|
// special case - output is scalar
|
||||||
if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
|
if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max<int>())) {
|
||||||
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64));
|
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype));
|
||||||
}
|
}
|
||||||
|
|
||||||
auto newShape = ShapeUtils::evalReduceShapeInfo('c', dims, in, DataType::INT64, false, false, block.getWorkspace());
|
return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace()));
|
||||||
return SHAPELIST(newShape);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -52,6 +52,32 @@ namespace sd {
|
||||||
DECLARE_CUSTOM_OP(argmin, 1, 1, false, 0, -2);
|
DECLARE_CUSTOM_OP(argmin, 1, 1, false, 0, -2);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation returns index of absolute max element in a given NDArray (optionally: along given dimension(s))
|
||||||
|
* Expected input:
|
||||||
|
* 0: N-dimensional array
|
||||||
|
* 1: optional axis vector
|
||||||
|
*
|
||||||
|
* Int args:
|
||||||
|
* 0: optional axis
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_argamax)
|
||||||
|
DECLARE_CUSTOM_OP(argamax, 1, 1, false, 0, -2);
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation returns index of absolute min element in a given NDArray (optionally: along given dimension(s))
|
||||||
|
* Expected input:
|
||||||
|
* 0: N-dimensional array
|
||||||
|
* 1: optional axis vector
|
||||||
|
*
|
||||||
|
* Int args:
|
||||||
|
* 0: optional axis
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_argamin)
|
||||||
|
DECLARE_CUSTOM_OP(argamin, 1, 1, false, 0, -2);
|
||||||
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This operation provides various normalization modes:
|
* This operation provides various normalization modes:
|
||||||
* 0: frobenius
|
* 0: frobenius
|
||||||
|
|
|
@ -0,0 +1,28 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/cpu/indexReductions.hpp>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void argAbsMax_, (const NDArray& input, NDArray& output, const std::vector<int>& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/cpu/indexReductions.hpp>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void argAbsMin_, (const NDArray& input, NDArray& output, const std::vector<int>& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/cpu/indexReductions.hpp>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void argMax_, (const NDArray& input, NDArray& output, const std::vector<int>& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,28 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/cpu/indexReductions.hpp>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
BUILD_DOUBLE_TEMPLATE(template void argMin_, (const NDArray& input, NDArray& output, const std::vector<int>& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,7 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/crop_and_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
#include "../crop_and_resize.hpp"
|
#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
|
@ -19,7 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/crop_and_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
#include "../crop_and_resize.hpp"
|
#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
|
@ -19,7 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/crop_and_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
#include "../crop_and_resize.hpp"
|
#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
|
@ -19,7 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/crop_and_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
#include "../crop_and_resize.hpp"
|
#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
|
@ -19,7 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/crop_and_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
#include "../crop_and_resize.hpp"
|
#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
|
@ -19,7 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/crop_and_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
#include "../crop_and_resize.hpp"
|
#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
|
@ -19,7 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/crop_and_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
#include "../crop_and_resize.hpp"
|
#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
|
@ -19,7 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/crop_and_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
#include "../crop_and_resize.hpp"
|
#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
|
@ -19,7 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/crop_and_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
#include "../crop_and_resize.hpp"
|
#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
|
@ -19,7 +19,7 @@
|
||||||
//
|
//
|
||||||
|
|
||||||
#include <ops/declarable/helpers/crop_and_resize.h>
|
#include <ops/declarable/helpers/crop_and_resize.h>
|
||||||
#include "../crop_and_resize.hpp"
|
#include "ops/declarable/helpers/cpu/crop_and_resize.hpp"
|
||||||
|
|
||||||
namespace sd {
|
namespace sd {
|
||||||
namespace ops {
|
namespace ops {
|
|
@ -0,0 +1,56 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/reductions.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Z>
|
||||||
|
void argMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
void argMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
void argAbsMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
void argAbsMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void argMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argMax_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
void argMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argMin_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
void argAbsMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argAbsMax_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
|
||||||
|
void argAbsMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argAbsMin_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,900 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
//
|
||||||
|
// @author AbdelRauf
|
||||||
|
//
|
||||||
|
#include <type_traits>
|
||||||
|
#include <cmath>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <memory>
|
||||||
|
#include <execution/Threads.h>
|
||||||
|
#include <execution/ThreadPool.h>
|
||||||
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
|
#include <ops/declarable/helpers/reductions.h>
|
||||||
|
#if 1
|
||||||
|
#define LOG_CALLS(X)
|
||||||
|
#else
|
||||||
|
|
||||||
|
#define LOG_CALLS(X) nd4j_printf("___%s_________%d+\n", __PRETTY_FUNCTION__, X);
|
||||||
|
#endif
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
constexpr int threadingThreshold = 4096;
|
||||||
|
template<typename X, typename Z, typename ReductionOp>
|
||||||
|
FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount)
|
||||||
|
{
|
||||||
|
argCurrent = 0;
|
||||||
|
current = buffer[0];
|
||||||
|
LOG_CALLS(0)
|
||||||
|
Nd4jLong j_offset = 0;
|
||||||
|
for (Z j = 0; j < loopCount; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, buffer[j], j);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp>
|
||||||
|
FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
|
||||||
|
{
|
||||||
|
argCurrent = 0;
|
||||||
|
current = buffer[0];
|
||||||
|
LOG_CALLS(0)
|
||||||
|
Nd4jLong j_offset = 0;
|
||||||
|
for (Z j = 0; j < loopCount; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, buffer[j_offset], j);
|
||||||
|
j_offset += inner_stride;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
|
||||||
|
FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount)
|
||||||
|
{
|
||||||
|
//skip 1 from the beginning or end depending the Order
|
||||||
|
constexpr size_t updated_index = LastIndexFaster ? 0 : 1;
|
||||||
|
constexpr size_t updated_rank = constRank - 1;
|
||||||
|
sd::CoordsState<updated_rank - 1> cst;
|
||||||
|
//we skip 1
|
||||||
|
size_t offset = sd::init_coords<updated_rank, 0, LastIndexFaster>(cst, 0, bases + updated_index, strides + updated_index);
|
||||||
|
Z startIndex = 0;
|
||||||
|
argCurrent = 0;
|
||||||
|
current = buffer[offset];
|
||||||
|
LOG_CALLS(0)
|
||||||
|
for (Z i = 0; i < outerLoopCount; i++) {
|
||||||
|
const X* inner_buffer = &(buffer[offset]);
|
||||||
|
//typename std::make_signed<Z>::type iArgMax = -1;
|
||||||
|
for (Z j = 0; j < innerLoopCount; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex);
|
||||||
|
}
|
||||||
|
//we skip 1
|
||||||
|
offset = sd::inc_coords<updated_rank, 0, LastIndexFaster>(cst, offset);
|
||||||
|
startIndex += innerLoopCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
|
||||||
|
FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
|
||||||
|
{
|
||||||
|
//skip 1 from the beginning or end depending the Order
|
||||||
|
constexpr size_t updated_index = LastIndexFaster ? 0 : 1;
|
||||||
|
constexpr size_t updated_rank = constRank - 1;
|
||||||
|
sd::CoordsState<updated_rank - 1> cst;
|
||||||
|
//we skip 1
|
||||||
|
size_t offset = sd::init_coords<updated_rank, 0, LastIndexFaster>(cst, 0, bases + updated_index, strides + updated_index);
|
||||||
|
Z startIndex = 0;
|
||||||
|
argCurrent = 0;
|
||||||
|
current = buffer[offset];
|
||||||
|
LOG_CALLS(0)
|
||||||
|
for (Z i = 0; i < outerLoopCount; i++) {
|
||||||
|
const X* inner_buffer = &(buffer[offset]);
|
||||||
|
for (Z j = 0; j < innerLoopCount; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, *inner_buffer, j + startIndex);
|
||||||
|
inner_buffer += inner_stride;
|
||||||
|
}
|
||||||
|
//we alreaddy skiped
|
||||||
|
offset = sd::inc_coords<updated_rank, 0, LastIndexFaster>(cst, offset);
|
||||||
|
startIndex += innerLoopCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
|
||||||
|
FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount)
|
||||||
|
{
|
||||||
|
size_t offset = 0;
|
||||||
|
Nd4jLong outerLoopCount = outerLoopStop - outerLoopStart;
|
||||||
|
Nd4jLong coords[MAX_RANK] = {};
|
||||||
|
Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
|
||||||
|
if (outerLoopStart > 0) {
|
||||||
|
sd::index2coords_C(outerLoopStart, rank - 1, bases, ptr_coords);
|
||||||
|
offset = sd::offset_from_coords(strides, ptr_coords, rank);
|
||||||
|
}
|
||||||
|
Z startIndex = outerLoopStart * innerLoopCount;
|
||||||
|
argCurrent = startIndex;
|
||||||
|
current = buffer[offset];
|
||||||
|
LOG_CALLS(0)
|
||||||
|
for (Z i = 0; i < outerLoopCount; i++) {
|
||||||
|
const X* inner_buffer = &(buffer[offset]);
|
||||||
|
//typename std::make_signed<Z>::type iArgMax = -1;
|
||||||
|
for (Z j = 0; j < innerLoopCount; j++) {
|
||||||
|
//nd4j_printf("%f\n", inner_buffer[j]);
|
||||||
|
ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex);
|
||||||
|
}
|
||||||
|
offset = inc_coords<true>(bases, strides, ptr_coords, offset, rank, 1);
|
||||||
|
//if (iArgMax >= 0) argCurrent = startIndex + iArgMax;
|
||||||
|
startIndex += innerLoopCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
|
||||||
|
FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
|
||||||
|
{
|
||||||
|
size_t offset = 0;
|
||||||
|
Nd4jLong outerLoopCount = outerLoopStop - outerLoopStart;
|
||||||
|
Nd4jLong coords[MAX_RANK] = {};
|
||||||
|
Nd4jLong* ptr_coords = (Nd4jLong*)&coords;
|
||||||
|
if (outerLoopStart > 0) {
|
||||||
|
sd::index2coords_C(outerLoopStart, rank - 1, bases, ptr_coords);
|
||||||
|
offset = sd::offset_from_coords(strides, ptr_coords, rank);
|
||||||
|
}
|
||||||
|
Z startIndex = outerLoopStart * innerLoopCount;
|
||||||
|
argCurrent = startIndex;
|
||||||
|
current = buffer[offset];
|
||||||
|
LOG_CALLS(0)
|
||||||
|
for (Z i = 0; i < outerLoopCount; i++) {
|
||||||
|
const X* inner_buffer = &(buffer[offset]);
|
||||||
|
//typename std::make_signed<Z>::type iArgMax = -1;
|
||||||
|
for (Z j = 0; j < innerLoopCount; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, inner_buffer[j * inner_stride], startIndex + j);
|
||||||
|
}
|
||||||
|
offset = inc_coords<true>(bases, strides, ptr_coords, offset, rank, 1);
|
||||||
|
//offset = inc_coords<LastIndexFaster>(bases, strides, ptr_coords, offset, rank, 1);
|
||||||
|
//if (iArgMax >= 0) argCurrent = startIndex + iArgMax;
|
||||||
|
startIndex += innerLoopCount;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp>
|
||||||
|
FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount)
|
||||||
|
{
|
||||||
|
argCurrent = 0;
|
||||||
|
current = buffer[0];
|
||||||
|
LOG_CALLS(0)
|
||||||
|
Nd4jLong loopCount4 = loopCount / 4;
|
||||||
|
Nd4jLong loopCountEnd = loopCount4 + (loopCount & 3);
|
||||||
|
const X* buffer1 = buffer + 1 * loopCount4;
|
||||||
|
const X* buffer2 = buffer1 + 1 * loopCount4;
|
||||||
|
const X* buffer3 = buffer2 + 1 * loopCount4;
|
||||||
|
X current1 = *buffer1;
|
||||||
|
X current2 = *buffer2;
|
||||||
|
X current3 = *buffer3;
|
||||||
|
Z argCurrent1 = 0;
|
||||||
|
Z argCurrent2 = 0;
|
||||||
|
Z argCurrent3 = 0;
|
||||||
|
for (Z j = 0; j < loopCount4; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, buffer[j], j);
|
||||||
|
ReductionOp::update(current1, argCurrent1, buffer1[j], j);
|
||||||
|
ReductionOp::update(current2, argCurrent2, buffer2[j], j);
|
||||||
|
ReductionOp::update(current3, argCurrent3, buffer3[j], j);
|
||||||
|
}
|
||||||
|
//tail
|
||||||
|
for (Z j = loopCount4; j < loopCountEnd; j++) {
|
||||||
|
ReductionOp::update(current3, argCurrent3, buffer3[j], j);
|
||||||
|
}
|
||||||
|
//merge
|
||||||
|
argCurrent1 += loopCount4;
|
||||||
|
argCurrent2 += 2 * loopCount4;
|
||||||
|
argCurrent3 += 3 * loopCount4;
|
||||||
|
ReductionOp::update(current, argCurrent, current1, argCurrent1);
|
||||||
|
ReductionOp::update(current, argCurrent, current2, argCurrent2);
|
||||||
|
ReductionOp::update(current, argCurrent, current3, argCurrent3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp>
|
||||||
|
FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
|
||||||
|
{
|
||||||
|
argCurrent = 0;
|
||||||
|
current = buffer[0];
|
||||||
|
LOG_CALLS(0)
|
||||||
|
Nd4jLong loopCount4 = loopCount / 4;
|
||||||
|
Nd4jLong loopCountEnd = loopCount4 + (loopCount & 3);
|
||||||
|
const X* buffer1 = buffer + inner_stride * loopCount4;
|
||||||
|
const X* buffer2 = buffer1 + inner_stride * loopCount4;
|
||||||
|
const X* buffer3 = buffer2 + inner_stride * loopCount4;
|
||||||
|
X current1 = *buffer1;
|
||||||
|
X current2 = *buffer2;
|
||||||
|
X current3 = *buffer3;
|
||||||
|
Z argCurrent1 = 0;
|
||||||
|
Z argCurrent2 = 0;
|
||||||
|
Z argCurrent3 = 0;
|
||||||
|
Nd4jLong j_offset = 0;
|
||||||
|
for (Z j = 0; j < loopCount4; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, buffer[j_offset], j);
|
||||||
|
ReductionOp::update(current1, argCurrent1, buffer1[j_offset], j);
|
||||||
|
ReductionOp::update(current2, argCurrent2, buffer2[j_offset], j);
|
||||||
|
ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j);
|
||||||
|
j_offset += inner_stride;
|
||||||
|
}
|
||||||
|
//tail
|
||||||
|
for (Z j = loopCount4; j < loopCountEnd; j++) {
|
||||||
|
ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j);
|
||||||
|
j_offset += inner_stride;
|
||||||
|
}
|
||||||
|
//merge
|
||||||
|
argCurrent1 += loopCount4;
|
||||||
|
argCurrent2 += 2 * loopCount4;
|
||||||
|
argCurrent3 += 3 * loopCount4;
|
||||||
|
ReductionOp::update(current, argCurrent, current1, argCurrent1);
|
||||||
|
ReductionOp::update(current, argCurrent, current2, argCurrent2);
|
||||||
|
ReductionOp::update(current, argCurrent, current3, argCurrent3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp>
|
||||||
|
FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount)
|
||||||
|
{
|
||||||
|
LOG_CALLS(0)
|
||||||
|
Z argCurrent = 0;
|
||||||
|
Z argCurrent1 = 0;
|
||||||
|
Z argCurrent2 = 0;
|
||||||
|
Z argCurrent3 = 0;
|
||||||
|
X current = buffer[0];
|
||||||
|
X current1 = buffer1[0];
|
||||||
|
X current2 = buffer2[0];
|
||||||
|
X current3 = buffer3[0];
|
||||||
|
for (Z j = 0; j < loopCount; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, buffer[j], j);
|
||||||
|
ReductionOp::update(current1, argCurrent1, buffer1[j], j);
|
||||||
|
ReductionOp::update(current2, argCurrent2, buffer2[j], j);
|
||||||
|
ReductionOp::update(current3, argCurrent3, buffer3[j], j);
|
||||||
|
}
|
||||||
|
*output = argCurrent;
|
||||||
|
*output1 = argCurrent1;
|
||||||
|
*output2 = argCurrent2;
|
||||||
|
*output3 = argCurrent3;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp>
|
||||||
|
FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount, const Nd4jLong& inner_stride)
|
||||||
|
{
|
||||||
|
LOG_CALLS(0)
|
||||||
|
Z argCurrent = 0;
|
||||||
|
Z argCurrent1 = 0;
|
||||||
|
Z argCurrent2 = 0;
|
||||||
|
Z argCurrent3 = 0;
|
||||||
|
X current = buffer[0];
|
||||||
|
X current1 = buffer1[0];
|
||||||
|
X current2 = buffer2[0];
|
||||||
|
X current3 = buffer3[0];
|
||||||
|
Nd4jLong j_offset = 0;
|
||||||
|
for (Z j = 0; j < loopCount; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, buffer[j_offset], j);
|
||||||
|
ReductionOp::update(current1, argCurrent1, buffer1[j_offset], j);
|
||||||
|
ReductionOp::update(current2, argCurrent2, buffer2[j_offset], j);
|
||||||
|
ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j);
|
||||||
|
j_offset += inner_stride;
|
||||||
|
}
|
||||||
|
*output = argCurrent;
|
||||||
|
*output1 = argCurrent1;
|
||||||
|
*output2 = argCurrent2;
|
||||||
|
*output3 = argCurrent3;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
|
||||||
|
FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3,
|
||||||
|
Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides,
|
||||||
|
const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount)
|
||||||
|
{
|
||||||
|
LOG_CALLS(0)
|
||||||
|
//skip 1 from the beginning or end depending the Order
|
||||||
|
constexpr size_t updated_index = LastIndexFaster ? 0 : 1;
|
||||||
|
constexpr size_t updated_rank = constRank - 1;
|
||||||
|
sd::CoordsState<updated_rank - 1> cst;
|
||||||
|
//we skip 1
|
||||||
|
size_t offset = sd::init_coords<updated_rank, 0, LastIndexFaster>(cst, 0, bases + updated_index, strides + updated_index);
|
||||||
|
Z startIndex = 0;
|
||||||
|
Z argCurrent = 0;
|
||||||
|
Z argCurrent1 = 0;
|
||||||
|
Z argCurrent2 = 0;
|
||||||
|
Z argCurrent3 = 0;
|
||||||
|
X current = buffer[0];
|
||||||
|
X current1 = buffer1[0];
|
||||||
|
X current2 = buffer2[0];
|
||||||
|
X current3 = buffer3[0];
|
||||||
|
//LOG_CALLS(0)
|
||||||
|
for (Z i = 0; i < outerLoopCount; i++) {
|
||||||
|
const X* inner_buffer = &(buffer[offset]);
|
||||||
|
const X* inner_buffer1 = &(buffer1[offset]);
|
||||||
|
const X* inner_buffer2 = &(buffer2[offset]);
|
||||||
|
const X* inner_buffer3 = &(buffer3[offset]);
|
||||||
|
//typename std::make_signed<Z>::type iArgMax = -1;
|
||||||
|
for (Z j = 0; j < innerLoopCount; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex);
|
||||||
|
ReductionOp::update(current1, argCurrent1, inner_buffer1[j], j + startIndex);
|
||||||
|
ReductionOp::update(current2, argCurrent2, inner_buffer2[j], j + startIndex);
|
||||||
|
ReductionOp::update(current3, argCurrent3, inner_buffer3[j], j + startIndex);
|
||||||
|
}
|
||||||
|
//we skip 1
|
||||||
|
offset = sd::inc_coords<updated_rank, 0, LastIndexFaster>(cst, offset);
|
||||||
|
startIndex += innerLoopCount;
|
||||||
|
}
|
||||||
|
*output = argCurrent;
|
||||||
|
*output1 = argCurrent1;
|
||||||
|
*output2 = argCurrent2;
|
||||||
|
*output3 = argCurrent3;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp, size_t constRank, bool LastIndexFaster = true>
|
||||||
|
FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3,
|
||||||
|
Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides,
|
||||||
|
const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride)
|
||||||
|
{
|
||||||
|
LOG_CALLS(0)
|
||||||
|
//skip 1 from the beginning or end depending the Order
|
||||||
|
constexpr size_t updated_index = LastIndexFaster ? 0 : 1;
|
||||||
|
constexpr size_t updated_rank = constRank - 1;
|
||||||
|
sd::CoordsState<updated_rank - 1> cst;
|
||||||
|
//we skip 1
|
||||||
|
size_t offset = sd::init_coords<updated_rank, 0, LastIndexFaster>(cst, 0, bases + updated_index, strides + updated_index);
|
||||||
|
Z startIndex = 0;
|
||||||
|
Z argCurrent = 0;
|
||||||
|
Z argCurrent1 = 0;
|
||||||
|
Z argCurrent2 = 0;
|
||||||
|
Z argCurrent3 = 0;
|
||||||
|
X current = buffer[0];
|
||||||
|
X current1 = buffer1[0];
|
||||||
|
X current2 = buffer2[0];
|
||||||
|
X current3 = buffer3[0];
|
||||||
|
//LOG_CALLS(0)
|
||||||
|
for (Z i = 0; i < outerLoopCount; i++) {
|
||||||
|
const X* inner_buffer = &(buffer[offset]);
|
||||||
|
const X* inner_buffer1 = &(buffer1[offset]);
|
||||||
|
const X* inner_buffer2 = &(buffer2[offset]);
|
||||||
|
const X* inner_buffer3 = &(buffer3[offset]);
|
||||||
|
//typename std::make_signed<Z>::type iArgMax = -1;
|
||||||
|
Nd4jLong inner_offset = 0;
|
||||||
|
for (Z j = 0; j < innerLoopCount; j++) {
|
||||||
|
ReductionOp::update(current, argCurrent, inner_buffer[inner_offset], j + startIndex);
|
||||||
|
ReductionOp::update(current1, argCurrent1, inner_buffer1[inner_offset], j + startIndex);
|
||||||
|
ReductionOp::update(current2, argCurrent2, inner_buffer2[inner_offset], j + startIndex);
|
||||||
|
ReductionOp::update(current3, argCurrent3, inner_buffer3[inner_offset], j + startIndex);
|
||||||
|
inner_offset += inner_stride;
|
||||||
|
}
|
||||||
|
//we skip 1
|
||||||
|
offset = sd::inc_coords<updated_rank, 0, LastIndexFaster>(cst, offset);
|
||||||
|
startIndex += innerLoopCount;
|
||||||
|
}
|
||||||
|
*output = argCurrent;
|
||||||
|
*output1 = argCurrent1;
|
||||||
|
*output2 = argCurrent2;
|
||||||
|
*output3 = argCurrent3;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
|
||||||
|
void argIndexCase1Scalar(const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
|
||||||
|
{
|
||||||
|
Nd4jLong inner_total;
|
||||||
|
Nd4jLong inner_last = 0;
|
||||||
|
int maxThreads = sd::Environment::getInstance()->maxMasterThreads();
|
||||||
|
if (second_rank == 1) {
|
||||||
|
inner_total = inner_bases[0];
|
||||||
|
if (inner_total < threadingThreshold) {
|
||||||
|
maxThreads = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
inner_total = getLength<LastIndexFaster>(inner_bases, second_rank, 1, inner_last);
|
||||||
|
if (inner_total * inner_last < threadingThreshold) {
|
||||||
|
maxThreads = 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::unique_ptr<X[]> maxValues(new X[maxThreads]);
|
||||||
|
std::unique_ptr<Z[]> maxIndices(new Z[maxThreads]);
|
||||||
|
X* ptrMaxValues = maxValues.get();
|
||||||
|
Z* ptrMaxIndices = maxIndices.get();
|
||||||
|
auto func = [ptrMaxValues, ptrMaxIndices, inner_last, second_rank, inner_bases, inner_strides, bufferX](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||||||
|
//LOG_CALLS(0)
|
||||||
|
const Nd4jLong inner_stride = LastIndexFaster ? inner_strides[second_rank - 1] : inner_strides[0];
|
||||||
|
Z argCurrent; X current;
|
||||||
|
if (second_rank == 1) {
|
||||||
|
const Nd4jLong loopTotal = stop - start;
|
||||||
|
if (inner_stride == 1) {
|
||||||
|
indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(&(bufferX[start]), current, argCurrent, loopTotal);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(&(bufferX[start * inner_stride]), current, argCurrent, loopTotal, inner_stride);
|
||||||
|
}
|
||||||
|
ptrMaxIndices[thread_id] = argCurrent + start;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (inner_stride == 1) {
|
||||||
|
indexInnerReduction<X, Z, ReductionOp, LastIndexFaster>(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last, inner_stride);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
indexInnerReduction<X, Z, ReductionOp, LastIndexFaster>(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last, inner_stride);
|
||||||
|
}
|
||||||
|
ptrMaxIndices[thread_id] = argCurrent;
|
||||||
|
}
|
||||||
|
ptrMaxValues[thread_id] = current;
|
||||||
|
};
|
||||||
|
#if 0
|
||||||
|
int Count = 0;
|
||||||
|
func(0, 0, inner_total, 1);
|
||||||
|
#else
|
||||||
|
int Count = samediff::Threads::parallel_tad(func, 0, inner_total, 1, maxThreads);
|
||||||
|
#endif
|
||||||
|
Z arg = 0;
|
||||||
|
X current = ptrMaxValues[0];
|
||||||
|
|
||||||
|
for (Z i = 1; i < Count; i++) {
|
||||||
|
ReductionOp::update(current, arg, ptrMaxValues[i], i);
|
||||||
|
}
|
||||||
|
|
||||||
|
*outputZ = ptrMaxIndices[arg];
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp, typename Movement, bool LastIndexFaster = true>
|
||||||
|
void argReductionInnerCases(Movement& movement, Nd4jLong loopTotal, const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
|
||||||
|
{
|
||||||
|
|
||||||
|
Nd4jLong inner_stride = true /*LastIndexFaster*/ ? inner_strides[second_rank - 1] : inner_strides[0];
|
||||||
|
|
||||||
|
Nd4jLong loopTotal_K = loopTotal / 4;
|
||||||
|
Nd4jLong loopTotal_Tail = loopTotal & 3;
|
||||||
|
if (inner_stride == 1) {
|
||||||
|
if (second_rank == 1) {
|
||||||
|
LOG_CALLS(0)
|
||||||
|
Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
Z* output0 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer1 = &(bufferX[movement.First()]);
|
||||||
|
Z* output1 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer2 = &(bufferX[movement.First()]);
|
||||||
|
Z* output2 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer3 = &(bufferX[movement.First()]);
|
||||||
|
Z* output3 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
indexInnerReductionRank1Block4<X, Z, ReductionOp>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total);
|
||||||
|
|
||||||
|
}
|
||||||
|
if (inner_total >= 2048) {
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
|
||||||
|
X current;
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total);
|
||||||
|
movement.increment();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
|
||||||
|
X current;
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
indexInnerReductionRank1<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total);
|
||||||
|
movement.increment();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
Nd4jLong inner_last;
|
||||||
|
Nd4jLong inner_loop = getLength<true>(inner_bases, second_rank, 1, inner_last);
|
||||||
|
if (second_rank == 2) {
|
||||||
|
LOG_CALLS(1)
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
Z* output0 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer1 = &(bufferX[movement.First()]);
|
||||||
|
Z* output1 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer2 = &(bufferX[movement.First()]);
|
||||||
|
Z* output2 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer3 = &(bufferX[movement.First()]);
|
||||||
|
Z* output3 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 2>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
|
||||||
|
inner_loop, inner_last);
|
||||||
|
|
||||||
|
}
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
|
||||||
|
X current;
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
indexInnerReductionConstRank<X, Z, ReductionOp, 2>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, inner_loop, inner_last);
|
||||||
|
movement.increment();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else if (second_rank == 3) {
|
||||||
|
LOG_CALLS(2)
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
Z* output0 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer1 = &(bufferX[movement.First()]);
|
||||||
|
Z* output1 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer2 = &(bufferX[movement.First()]);
|
||||||
|
Z* output2 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer3 = &(bufferX[movement.First()]);
|
||||||
|
Z* output3 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 3>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
|
||||||
|
inner_loop, inner_last);
|
||||||
|
|
||||||
|
}
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
|
||||||
|
X current;
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
indexInnerReductionConstRank<X, Z, ReductionOp, 3>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
|
||||||
|
inner_loop, inner_last);
|
||||||
|
movement.increment();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
LOG_CALLS(3)
|
||||||
|
//nd4j_printf("-----%d \n", loopTotal);
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal; i++) {
|
||||||
|
X current;
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
indexInnerReduction<X, Z, ReductionOp>(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0,
|
||||||
|
inner_loop, inner_last);
|
||||||
|
movement.increment();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (second_rank == 1) {
|
||||||
|
LOG_CALLS(10)
|
||||||
|
Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
Z* output0 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer1 = &(bufferX[movement.First()]);
|
||||||
|
Z* output1 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer2 = &(bufferX[movement.First()]);
|
||||||
|
Z* output2 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer3 = &(bufferX[movement.First()]);
|
||||||
|
Z* output3 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
indexInnerReductionRank1Block4<X, Z, ReductionOp>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total, inner_stride);
|
||||||
|
|
||||||
|
}
|
||||||
|
if (inner_total >= 2048) {
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
|
||||||
|
X current;
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
indexInnerReductionRank1Block4WithMerge<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride);
|
||||||
|
movement.increment();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
|
||||||
|
X current;
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
indexInnerReductionRank1<X, Z, ReductionOp>(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride);
|
||||||
|
movement.increment();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
Nd4jLong inner_last;
|
||||||
|
Nd4jLong inner_loop = getLength<true>(inner_bases, second_rank, 1, inner_last);
|
||||||
|
if (second_rank == 2) {
|
||||||
|
LOG_CALLS(11)
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
Z* output0 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer1 = &(bufferX[movement.First()]);
|
||||||
|
Z* output1 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer2 = &(bufferX[movement.First()]);
|
||||||
|
Z* output2 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer3 = &(bufferX[movement.First()]);
|
||||||
|
Z* output3 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 2>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
|
||||||
|
inner_loop, inner_last, inner_stride);
|
||||||
|
|
||||||
|
}
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
|
||||||
|
X current;
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
indexInnerReductionConstRank<X, Z, ReductionOp, 2>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
|
||||||
|
inner_loop, inner_last, inner_stride);
|
||||||
|
movement.increment();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else if (second_rank == 3) {
|
||||||
|
LOG_CALLS(12)
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_K; i++) {
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
Z* output0 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer1 = &(bufferX[movement.First()]);
|
||||||
|
Z* output1 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer2 = &(bufferX[movement.First()]);
|
||||||
|
Z* output2 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
const X* buffer3 = &(bufferX[movement.First()]);
|
||||||
|
Z* output3 = &(outputZ[movement.Second()]);
|
||||||
|
movement.increment();
|
||||||
|
indexInnerReductionConstRankBlock4<X, Z, ReductionOp, 3>(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides,
|
||||||
|
inner_loop, inner_last, inner_stride);
|
||||||
|
|
||||||
|
}
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal_Tail; i++) {
|
||||||
|
X current;
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
indexInnerReductionConstRank<X, Z, ReductionOp, 3>(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides,
|
||||||
|
inner_loop, inner_last, inner_stride);
|
||||||
|
movement.increment();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
LOG_CALLS(13)
|
||||||
|
//nd4j_printf("-------%d inner loop %d inner_last %d\n", loopTotal, inner_loop,inner_last);
|
||||||
|
for (Nd4jLong i = 0; i < loopTotal; i++) {
|
||||||
|
X current;
|
||||||
|
const X* buffer0 = &(bufferX[movement.First()]);
|
||||||
|
indexInnerReduction<X, Z, ReductionOp>(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0,
|
||||||
|
inner_loop, inner_last, inner_stride);
|
||||||
|
movement.increment();
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp, bool LastIndexFaster = true>
|
||||||
|
void argIndexCaseNonScalar(const int& first_rank, const int& output_rank, bool squashed, const int& second_rank,
|
||||||
|
const Nd4jLong*& outer_bases,const Nd4jLong* outer_strides,const Nd4jLong* output_strides, const Nd4jLong &output_stride,
|
||||||
|
const Nd4jLong*& inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ)
|
||||||
|
{
|
||||||
|
|
||||||
|
Nd4jLong total = getLength<LastIndexFaster>(outer_bases, first_rank);
|
||||||
|
Nd4jLong inner_stride = true /*LastIndexFaster*/ ? inner_strides[second_rank - 1] : inner_strides[0];
|
||||||
|
Nd4jLong outer_stride = LastIndexFaster ? outer_strides[second_rank - 1] : outer_strides[0];
|
||||||
|
auto func = [first_rank, output_rank, squashed, outer_bases, outer_strides, output_strides, output_stride, second_rank, inner_bases, inner_strides, bufferX, outputZ](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void {
|
||||||
|
|
||||||
|
Nd4jLong loopTotal = stop - start;
|
||||||
|
Nd4jLong stride = LastIndexFaster ? outer_strides[first_rank - 1] : outer_strides[0];
|
||||||
|
if (first_rank == 1) {
|
||||||
|
|
||||||
|
if (stride == 1) {
|
||||||
|
ZipGenericCoordsRank1Stride1 movement;
|
||||||
|
movement.init(nullptr, nullptr, nullptr, 0, start);
|
||||||
|
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
ZipGenericCoordsRank1BothStrideN movement;
|
||||||
|
movement.init(nullptr, &stride, &output_stride, 0, start);
|
||||||
|
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else if (squashed && first_rank <= output_rank) {
|
||||||
|
if (first_rank == 2) {
|
||||||
|
if (output_stride == 1) {
|
||||||
|
ZipGenericCoordsConstMovementSecondStride1<2, LastIndexFaster> movement;
|
||||||
|
movement.init(outer_bases, outer_strides, nullptr, first_rank, start);
|
||||||
|
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
ZipGenericCoordsConstMovementSecondStrideN<2, LastIndexFaster> movement;
|
||||||
|
movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
|
||||||
|
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else if (first_rank == 3) {
|
||||||
|
if (output_stride == 1) {
|
||||||
|
ZipGenericCoordsConstMovementSecondStride1<3, LastIndexFaster> movement;
|
||||||
|
movement.init(outer_bases, outer_strides, nullptr, first_rank, start);
|
||||||
|
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
ZipGenericCoordsConstMovementSecondStrideN<3, LastIndexFaster> movement;
|
||||||
|
movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
|
||||||
|
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
ZipGenericCoordsMovementSecondStrideN< LastIndexFaster> movement;
|
||||||
|
movement.init(outer_bases, outer_strides, &output_stride, first_rank, start);
|
||||||
|
|
||||||
|
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
ZipGenericCoordsMovement<LastIndexFaster> movement;
|
||||||
|
movement.init(outer_bases, outer_strides, output_strides, first_rank, start);
|
||||||
|
|
||||||
|
argReductionInnerCases<X, Z, ReductionOp>(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
};
|
||||||
|
#if 0
|
||||||
|
func(0, 0, total, 1);
|
||||||
|
#else
|
||||||
|
//
|
||||||
|
uint32_t numThreads = sd::Environment::getInstance()->maxMasterThreads();
|
||||||
|
Nd4jLong inner_total = getLength<true>(inner_bases, second_rank);
|
||||||
|
if (total * inner_total <= threadingThreshold) {
|
||||||
|
numThreads = 1;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (inner_stride > outer_stride && total <= 256) {
|
||||||
|
auto desired = total > 4 ? (total / 4) : 1;
|
||||||
|
numThreads = numThreads > desired ? desired : numThreads;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
samediff::Threads::parallel_tad(func, 0, total, 1, numThreads);
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z, typename ReductionOp>
|
||||||
|
void argIndex_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
char input_order = input.ordering();
|
||||||
|
bool try_squash_outer = (input_order == output.ordering()) && output.ews() != 0;
|
||||||
|
const Nd4jLong* input_shapeInfo = input.shapeInfo();
|
||||||
|
const Nd4jLong* output_shapeInfo = output.shapeInfo();
|
||||||
|
const Nd4jLong rank = input_shapeInfo[0];
|
||||||
|
const Nd4jLong* input_bases = &(input_shapeInfo[1]);
|
||||||
|
const Nd4jLong* input_strides = &(input_shapeInfo[rank + 1]);
|
||||||
|
const Nd4jLong output_rank = output_shapeInfo[0];
|
||||||
|
const Nd4jLong* output_strides = &(output_shapeInfo[output_rank + 1]);
|
||||||
|
Nd4jLong new_bases[MAX_RANK];
|
||||||
|
Nd4jLong new_strides[MAX_RANK];
|
||||||
|
int first_begin, first_end, second_begin, second_end;
|
||||||
|
//rePartition into two parts based on the selection
|
||||||
|
rePartition(input_order, dimensions, rank, input_bases, input_strides, new_bases, new_strides, first_begin, first_end, second_begin, second_end, try_squash_outer, input_order == 'c');
|
||||||
|
int first_rank = first_end - first_begin; //the first rank can be 0 for scalar cases
|
||||||
|
int second_rank = second_end - second_begin;
|
||||||
|
auto bufferX = input.bufferAsT<X>();
|
||||||
|
auto outputZ = output.bufferAsT<Z>();
|
||||||
|
const Nd4jLong* outer_bases = &(new_bases[first_begin]);
|
||||||
|
const Nd4jLong* outer_strides = &(new_strides[first_begin]);
|
||||||
|
const Nd4jLong* inner_bases = &(new_bases[second_begin]);
|
||||||
|
const Nd4jLong* inner_strides = &(new_strides[second_begin]);
|
||||||
|
const Nd4jLong output_stride = output.ordering() == 'c' ? output_strides[output_rank-1]:output_strides[0];
|
||||||
|
if (input_order == 'c') {
|
||||||
|
if (first_rank == 0) {
|
||||||
|
argIndexCase1Scalar<X, Z, ReductionOp>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
argIndexCaseNonScalar<X, Z, ReductionOp>(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides,
|
||||||
|
output_stride,inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (first_rank == 0) {
|
||||||
|
LOG_CALLS(0);
|
||||||
|
if (second_rank == 1) {
|
||||||
|
argIndexCase1Scalar<X, Z, ReductionOp, false>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
argIndexCase1Scalar<X, Z, ReductionOp, true>(second_rank, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
LOG_CALLS(1);
|
||||||
|
argIndexCaseNonScalar<X, Z, ReductionOp,false>(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides,
|
||||||
|
output_stride, inner_bases, inner_strides, bufferX, outputZ);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename X, typename Z>
|
||||||
|
struct IndexMax {
|
||||||
|
static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
|
||||||
|
if (candidate > current) {
|
||||||
|
current = candidate;
|
||||||
|
currentIndex = candidateIndex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename X, typename Z>
|
||||||
|
struct IndexMin {
|
||||||
|
static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
|
||||||
|
if (candidate < current) {
|
||||||
|
current = candidate;
|
||||||
|
currentIndex = candidateIndex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename X, typename Z>
|
||||||
|
struct IndexAbsMax {
|
||||||
|
static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
|
||||||
|
auto absCandidate = sd::math::nd4j_abs<X>(candidate);
|
||||||
|
if (absCandidate > current) {
|
||||||
|
current = absCandidate;
|
||||||
|
currentIndex = candidateIndex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template <typename X, typename Z>
|
||||||
|
struct IndexAbsMin {
|
||||||
|
static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) {
|
||||||
|
auto absCandidate = sd::math::nd4j_abs<X>(candidate);
|
||||||
|
if (absCandidate < current) {
|
||||||
|
current = absCandidate;
|
||||||
|
currentIndex = candidateIndex;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
template<typename X, typename Z>
|
||||||
|
void argMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
return argIndex_<X, Z, IndexMax<X, Z>>(input, output, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
void argMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
return argIndex_<X, Z, IndexMin<X, Z>>(input, output, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
void argAbsMax_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
return argIndex_<X, Z, IndexAbsMax<X, Z>>(input, output, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename X, typename Z>
|
||||||
|
void argAbsMin_(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
return argIndex_<X, Z, IndexAbsMin<X, Z>>(input, output, dimensions);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,106 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2020 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/reductions.h>
|
||||||
|
#include <legacy/NativeOpExecutioner.h>
|
||||||
|
#include <helpers/ConstantTadHelper.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
//////////////////////////////////////////////////////////////////////////
|
||||||
|
void argMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&input});
|
||||||
|
if (output.isScalar()) {
|
||||||
|
NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexMax, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
|
||||||
|
|
||||||
|
NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexMax,
|
||||||
|
input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(),
|
||||||
|
(int*) nullptr, dimensions.size(),
|
||||||
|
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({ &output }, { &input });
|
||||||
|
}
|
||||||
|
|
||||||
|
void argMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
NDArray::prepareSpecialUse({ &output }, { &input });
|
||||||
|
if (output.isScalar()) {
|
||||||
|
NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexMin, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
|
||||||
|
|
||||||
|
NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexMin,
|
||||||
|
input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(),
|
||||||
|
(int*) nullptr, dimensions.size(),
|
||||||
|
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({ &output }, { &input });
|
||||||
|
}
|
||||||
|
|
||||||
|
void argAbsMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
NDArray::prepareSpecialUse({ &output }, { &input });
|
||||||
|
if (output.isScalar()) {
|
||||||
|
NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMax, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
|
||||||
|
|
||||||
|
NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMax,
|
||||||
|
input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(),
|
||||||
|
(int*) nullptr, dimensions.size(),
|
||||||
|
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({ &output }, { &input });
|
||||||
|
}
|
||||||
|
|
||||||
|
void argAbsMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions) {
|
||||||
|
NDArray::prepareSpecialUse({ &output }, { &input });
|
||||||
|
if (output.isScalar()) {
|
||||||
|
NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMin, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo());
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions);
|
||||||
|
|
||||||
|
NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMin,
|
||||||
|
input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(),
|
||||||
|
nullptr,
|
||||||
|
output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(),
|
||||||
|
(int *) nullptr, dimensions.size(),
|
||||||
|
tadPack.specialShapeInfo(), tadPack.specialOffsets());
|
||||||
|
}
|
||||||
|
|
||||||
|
NDArray::registerSpecialUse({&output}, {&input});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,41 @@
|
||||||
|
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2019 Konduit K.K.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
//
|
||||||
|
// @author AbdelRauf (rauf@konduit.ai)
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef LIBND4J_HELPERS_REDUCTIONS_H
|
||||||
|
#define LIBND4J_HELPERS_REDUCTIONS_H
|
||||||
|
|
||||||
|
#include <system/op_boilerplate.h>
|
||||||
|
#include <math/templatemath.h>
|
||||||
|
#include <array/NDArray.h>
|
||||||
|
|
||||||
|
namespace sd {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
|
||||||
|
void argMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
|
||||||
|
void argAbsMax(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
|
||||||
|
void argMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
|
||||||
|
void argAbsMin(const NDArray& input, NDArray& output, const std::vector<int>& dimensions);
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -40,6 +40,19 @@ public:
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests19, test_argmax_maxint_vector_1) {
|
||||||
|
auto x = NDArrayFactory::create<float>('c', {3}, {0.1f, 0.5f, 0.7f});
|
||||||
|
auto z = NDArrayFactory::create<Nd4jLong>(0);
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>(2);
|
||||||
|
|
||||||
|
sd::ops::argmax op;
|
||||||
|
auto status = op.execute({&x}, {&z}, {DataTypeUtils::max<int>()});
|
||||||
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
ASSERT_EQ(e, z);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests19, test_threshold_encode_1) {
|
TEST_F(DeclarableOpsTests19, test_threshold_encode_1) {
|
||||||
auto x = NDArrayFactory::create<double>('c', {3}, {1.5, 2.5, -3.5});
|
auto x = NDArrayFactory::create<double>('c', {3}, {1.5, 2.5, -3.5});
|
||||||
auto exp_encoded = NDArrayFactory::create<int>('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3});
|
auto exp_encoded = NDArrayFactory::create<int>('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3});
|
||||||
|
@ -276,6 +289,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
TEST_F(DeclarableOpsTests19, test_matmul_ccc) {
|
TEST_F(DeclarableOpsTests19, test_matmul_ccc) {
|
||||||
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
auto x = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
auto y = NDArrayFactory::create<float>('c', {10, 10});
|
||||||
|
|
|
@ -43,9 +43,12 @@
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <performance/benchmarking/FullBenchmarkSuit.h>
|
#include <performance/benchmarking/FullBenchmarkSuit.h>
|
||||||
#include <performance/benchmarking/LightBenchmarkSuit.h>
|
#include <performance/benchmarking/LightBenchmarkSuit.h>
|
||||||
|
#include <random>
|
||||||
#include <ops/declarable/helpers/legacy_helpers.h>
|
#include <ops/declarable/helpers/legacy_helpers.h>
|
||||||
#include <ops/declarable/helpers/addBias.h>
|
#include <ops/declarable/helpers/addBias.h>
|
||||||
|
#include <ops/declarable/helpers/axis.h>
|
||||||
|
#include <ops/declarable/helpers/reductions.h>
|
||||||
|
#include <helpers/LoopsCoordsHelper.h>
|
||||||
|
|
||||||
using namespace sd;
|
using namespace sd;
|
||||||
using namespace sd::graph;
|
using namespace sd::graph;
|
||||||
|
@ -275,6 +278,256 @@ TEST_F(PlaygroundTests, test_one_off_ops_1) {
|
||||||
op.execute({&x, &y}, {&z});
|
op.execute({&x, &y}, {&z});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#if defined(INDEX_REDUCTIONS_BENCH_TESTS)
|
||||||
|
//temporarly, testing against the original one
|
||||||
|
void original_argmax(const NDArray& input, std::vector<int>& axis, NDArray& output) {
|
||||||
|
sd::ops::helpers::adjustAxis(input.rankOf(), axis);
|
||||||
|
input.applyIndexReduce(sd::indexreduce::IndexMax, output, axis);
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename T>
|
||||||
|
void fill_random(sd::NDArray& arr) {
|
||||||
|
Nd4jLong coords[MAX_RANK] = {};
|
||||||
|
std::random_device rd;
|
||||||
|
std::mt19937 gen(rd());
|
||||||
|
//for floats
|
||||||
|
std::uniform_real_distribution<T> dis((T)-10.0, (T)22.9);
|
||||||
|
T* x = arr.bufferAsT<T>();
|
||||||
|
Nd4jLong* shapeInfo = arr.getShapeInfo();
|
||||||
|
Nd4jLong* strides = arr.stridesOf();
|
||||||
|
Nd4jLong rank = shapeInfo[0];
|
||||||
|
Nd4jLong* bases = &(shapeInfo[1]);
|
||||||
|
size_t t = 1;
|
||||||
|
for (size_t i = 0; i < rank ; i++) {
|
||||||
|
t *= bases[i];
|
||||||
|
}
|
||||||
|
size_t offset = 0;
|
||||||
|
if (arr.ordering() == 'c') {
|
||||||
|
|
||||||
|
for (size_t i = 0; i < t; i++) {
|
||||||
|
x[offset] = dis(gen) ;
|
||||||
|
offset = sd::inc_coords(bases, strides, coords, offset, rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
|
||||||
|
for (size_t i = 0; i < t; i++) {
|
||||||
|
x[offset] = dis(gen) ;
|
||||||
|
offset = sd::inc_coords<false>(bases, strides, coords, offset, rank);
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void testLegacy(bool random) {
|
||||||
|
#if 0
|
||||||
|
int bases[] = { 3, 2, 4, 5, 7 };
|
||||||
|
constexpr int Loop = 1;
|
||||||
|
#else
|
||||||
|
int bases[] = { 8, 32, 64, 32, 64 };
|
||||||
|
constexpr int Loop = 10;
|
||||||
|
#endif
|
||||||
|
constexpr int N = 5;
|
||||||
|
|
||||||
|
auto x = NDArrayFactory::create<float>('c', { bases[0], bases[1], bases[2], bases[3], bases[4] });
|
||||||
|
if (!random) {
|
||||||
|
x.linspace(1);
|
||||||
|
}
|
||||||
|
else{
|
||||||
|
fill_random<float>(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define COMBINATIONS 1
|
||||||
|
#if COMBINATIONS
|
||||||
|
//https://www.rosettacode.org/wiki/Combinations#C.2B.2B
|
||||||
|
for (int k = N; k >= 1; k--) {
|
||||||
|
|
||||||
|
std::string bitmask(k, 1); // K leading 1's
|
||||||
|
bitmask.resize(N, 0); // N-K trailing 0's
|
||||||
|
|
||||||
|
do {
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<int> dimension;
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> output_bases;
|
||||||
|
|
||||||
|
for (int i = 0; i < N; ++i) // [0..N-1] integers
|
||||||
|
{
|
||||||
|
if (bitmask[i]) dimension.push_back(i);
|
||||||
|
else {
|
||||||
|
output_bases.push_back(bases[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
std::vector<int> dimension = { 0,1,2,3 };
|
||||||
|
int k = 4;
|
||||||
|
#endif
|
||||||
|
auto dim = NDArrayFactory::create<int>(dimension);
|
||||||
|
|
||||||
|
#if 1
|
||||||
|
nd4j_printf("C(N:%d K:%d) \n", N, k);
|
||||||
|
dim.printIndexedBuffer("Dimension");
|
||||||
|
for (int xind : dimension) {
|
||||||
|
nd4j_printf(" %d ,", bases[xind]);
|
||||||
|
}
|
||||||
|
nd4j_printf("%s", "\n");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> values;
|
||||||
|
sd::ResultSet result;
|
||||||
|
for (int e = 0; e < Loop; e++) {
|
||||||
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create<Nd4jLong>('c', output_bases) : NDArrayFactory::create<Nd4jLong>(0);
|
||||||
|
original_argmax(x, dimension, exp);
|
||||||
|
auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
|
||||||
|
values.emplace_back(outerTime);
|
||||||
|
}
|
||||||
|
|
||||||
|
std::sort(values.begin(), values.end());
|
||||||
|
|
||||||
|
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
|
||||||
|
#if COMBINATIONS
|
||||||
|
|
||||||
|
} while (std::prev_permutation(bitmask.begin(), bitmask.end()));
|
||||||
|
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
#define DEBUG 1
|
||||||
|
|
||||||
|
void testNewReduction(bool random, bool checkCorrectness = false , char order ='c') {
|
||||||
|
std::vector<Nd4jLong> arr_dimensions;
|
||||||
|
#if defined(DEBUG)
|
||||||
|
int bases[] = { 3, 2, 3, 3, 5 ,4,7,4,7,7 };
|
||||||
|
constexpr int Loop = 1;
|
||||||
|
constexpr int N = 10;
|
||||||
|
#else
|
||||||
|
int bases[] = { 8, 32, 64, 32, 64 };
|
||||||
|
constexpr int Loop = 10;
|
||||||
|
constexpr int N = 5;
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
for (int i = 0; i < N; i++) {
|
||||||
|
arr_dimensions.push_back(bases[i]);
|
||||||
|
}
|
||||||
|
auto x = NDArrayFactory::create<float>(order,arr_dimensions);
|
||||||
|
if (!random) {
|
||||||
|
x.linspace(1);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
fill_random<float>(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
#define COMBINATIONS 1
|
||||||
|
#if COMBINATIONS
|
||||||
|
//https://www.rosettacode.org/wiki/Combinations#C.2B.2B
|
||||||
|
for (int k = N; k >= 1; k--) {
|
||||||
|
|
||||||
|
std::string bitmask(k, 1); // K leading 1's
|
||||||
|
bitmask.resize(N, 0); // N-K trailing 0's
|
||||||
|
|
||||||
|
do {
|
||||||
|
|
||||||
|
|
||||||
|
std::vector<int> dimension;
|
||||||
|
|
||||||
|
std::vector<Nd4jLong> output_bases;
|
||||||
|
|
||||||
|
for (int i = 0; i < N; ++i) // [0..N-1] integers
|
||||||
|
{
|
||||||
|
if (bitmask[i]) dimension.push_back(i);
|
||||||
|
else {
|
||||||
|
output_bases.push_back(bases[i]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
std::vector<int> dimension = { 0,1,2,3 };
|
||||||
|
int k = 4;
|
||||||
|
#endif
|
||||||
|
auto dim = NDArrayFactory::create<int>(dimension);
|
||||||
|
|
||||||
|
#if 1
|
||||||
|
nd4j_printf("C(N:%d K:%d) \n", N, k);
|
||||||
|
dim.printIndexedBuffer("Dimension");
|
||||||
|
for (int xind : dimension) {
|
||||||
|
nd4j_printf(" %d ,", bases[xind]);
|
||||||
|
}
|
||||||
|
nd4j_printf("%s", "\n");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
|
||||||
|
sd::ops::argmax op;
|
||||||
|
std::vector<Nd4jLong> values;
|
||||||
|
sd::ResultSet result;
|
||||||
|
for (int e = 0; e < Loop; e++) {
|
||||||
|
auto timeStart = std::chrono::system_clock::now();
|
||||||
|
result = op.evaluate({ &x, &dim }, {}, {});
|
||||||
|
auto timeEnd = std::chrono::system_clock::now();
|
||||||
|
auto outerTime = std::chrono::duration_cast<std::chrono::microseconds>(timeEnd - timeStart).count();
|
||||||
|
values.emplace_back(outerTime);
|
||||||
|
}
|
||||||
|
auto z = result.at(0);
|
||||||
|
|
||||||
|
if (checkCorrectness) {
|
||||||
|
//check for the correctness
|
||||||
|
NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create<Nd4jLong>('c', output_bases) : NDArrayFactory::create<Nd4jLong>(0);
|
||||||
|
original_argmax(x, dimension, exp);
|
||||||
|
|
||||||
|
|
||||||
|
#if 0// defined(DEBUG)
|
||||||
|
x.printIndexedBuffer("X");
|
||||||
|
exp.printIndexedBuffer("Expected");
|
||||||
|
z->printIndexedBuffer("Z");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
ASSERT_TRUE(exp.isSameShape(z));
|
||||||
|
ASSERT_TRUE(exp.equalsTo(z));
|
||||||
|
}
|
||||||
|
std::sort(values.begin(), values.end());
|
||||||
|
|
||||||
|
nd4j_printf("Time: %lld us;\n", values[values.size() / 2]);
|
||||||
|
#if COMBINATIONS
|
||||||
|
|
||||||
|
} while (std::prev_permutation(bitmask.begin(), bitmask.end()));
|
||||||
|
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
constexpr bool test_corr = true;
|
||||||
|
#if !defined(DEBUG)
|
||||||
|
TEST_F(PlaygroundTests, ArgMaxPerfLinspace) {
|
||||||
|
testNewReduction(false, test_corr);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
TEST_F(PlaygroundTests, ArgMaxPerfRandom) {
|
||||||
|
testNewReduction(true, test_corr);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) {
|
||||||
|
testNewReduction(true, test_corr, 'f');
|
||||||
|
}
|
||||||
|
|
||||||
|
#if !defined(DEBUG)
|
||||||
|
TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) {
|
||||||
|
testLegacy(false);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(PlaygroundTests, ArgMaxPerfLegacyRandom) {
|
||||||
|
testLegacy(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
/*
|
/*
|
||||||
|
|
||||||
|
|
|
@ -106,7 +106,7 @@ public class SDBaseOps {
|
||||||
public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) {
|
public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) {
|
||||||
SDValidation.validateNumerical("argmax", "in", in);
|
SDValidation.validateNumerical("argmax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, keepDims, dimensions).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -130,7 +130,7 @@ public class SDBaseOps {
|
||||||
public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) {
|
public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) {
|
||||||
SDValidation.validateNumerical("argmax", "in", in);
|
SDValidation.validateNumerical("argmax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, keepDims, dimensions).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -153,7 +153,7 @@ public class SDBaseOps {
|
||||||
public SDVariable argmax(SDVariable in, int... dimensions) {
|
public SDVariable argmax(SDVariable in, int... dimensions) {
|
||||||
SDValidation.validateNumerical("argmax", "in", in);
|
SDValidation.validateNumerical("argmax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, false, dimensions).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -176,7 +176,7 @@ public class SDBaseOps {
|
||||||
public SDVariable argmax(String name, SDVariable in, int... dimensions) {
|
public SDVariable argmax(String name, SDVariable in, int... dimensions) {
|
||||||
SDValidation.validateNumerical("argmax", "in", in);
|
SDValidation.validateNumerical("argmax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, false, dimensions).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,7 +203,7 @@ public class SDBaseOps {
|
||||||
public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) {
|
public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) {
|
||||||
SDValidation.validateNumerical("argmin", "in", in);
|
SDValidation.validateNumerical("argmin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, keepDims, dimensions).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -230,7 +230,7 @@ public class SDBaseOps {
|
||||||
public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... dimensions) {
|
public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... dimensions) {
|
||||||
SDValidation.validateNumerical("argmin", "in", in);
|
SDValidation.validateNumerical("argmin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, keepDims, dimensions).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -256,7 +256,7 @@ public class SDBaseOps {
|
||||||
public SDVariable argmin(SDVariable in, int... dimensions) {
|
public SDVariable argmin(SDVariable in, int... dimensions) {
|
||||||
SDValidation.validateNumerical("argmin", "in", in);
|
SDValidation.validateNumerical("argmin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, false, dimensions).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -282,7 +282,7 @@ public class SDBaseOps {
|
||||||
public SDVariable argmin(String name, SDVariable in, int... dimensions) {
|
public SDVariable argmin(String name, SDVariable in, int... dimensions) {
|
||||||
SDValidation.validateNumerical("argmin", "in", in);
|
SDValidation.validateNumerical("argmin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, false, dimensions).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1875,7 +1875,7 @@ public class SDMath extends SDOps {
|
||||||
public SDVariable iamax(SDVariable in, int... dimensions) {
|
public SDVariable iamax(SDVariable in, int... dimensions) {
|
||||||
SDValidation.validateNumerical("iamax", "in", in);
|
SDValidation.validateNumerical("iamax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, false, dimensions).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1890,7 +1890,7 @@ public class SDMath extends SDOps {
|
||||||
public SDVariable iamax(String name, SDVariable in, int... dimensions) {
|
public SDVariable iamax(String name, SDVariable in, int... dimensions) {
|
||||||
SDValidation.validateNumerical("iamax", "in", in);
|
SDValidation.validateNumerical("iamax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, false, dimensions).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1906,7 +1906,7 @@ public class SDMath extends SDOps {
|
||||||
public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) {
|
public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) {
|
||||||
SDValidation.validateNumerical("iamax", "in", in);
|
SDValidation.validateNumerical("iamax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, keepDims, dimensions).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1922,7 +1922,7 @@ public class SDMath extends SDOps {
|
||||||
public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) {
|
public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) {
|
||||||
SDValidation.validateNumerical("iamax", "in", in);
|
SDValidation.validateNumerical("iamax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, keepDims, dimensions).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1937,7 +1937,7 @@ public class SDMath extends SDOps {
|
||||||
public SDVariable iamin(SDVariable in, int... dimensions) {
|
public SDVariable iamin(SDVariable in, int... dimensions) {
|
||||||
SDValidation.validateNumerical("iamin", "in", in);
|
SDValidation.validateNumerical("iamin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, false, dimensions).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1952,7 +1952,7 @@ public class SDMath extends SDOps {
|
||||||
public SDVariable iamin(String name, SDVariable in, int... dimensions) {
|
public SDVariable iamin(String name, SDVariable in, int... dimensions) {
|
||||||
SDValidation.validateNumerical("iamin", "in", in);
|
SDValidation.validateNumerical("iamin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, false, dimensions).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1968,7 +1968,7 @@ public class SDMath extends SDOps {
|
||||||
public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) {
|
public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) {
|
||||||
SDValidation.validateNumerical("iamin", "in", in);
|
SDValidation.validateNumerical("iamin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, keepDims, dimensions).outputVariable();
|
return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable();
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -1984,7 +1984,7 @@ public class SDMath extends SDOps {
|
||||||
public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) {
|
public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) {
|
||||||
SDValidation.validateNumerical("iamin", "in", in);
|
SDValidation.validateNumerical("iamin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, keepDims, dimensions).outputVariable();
|
SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable();
|
||||||
return sd.updateVariableNameAndReference(out, name);
|
return sd.updateVariableNameAndReference(out, name);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -682,14 +682,6 @@ public class LegacyOpMapper {
|
||||||
|
|
||||||
public static Class<?> indexReduceClass(int opNum){
|
public static Class<?> indexReduceClass(int opNum){
|
||||||
switch (opNum){
|
switch (opNum){
|
||||||
case 0:
|
|
||||||
return IMax.class;
|
|
||||||
case 1:
|
|
||||||
return IMin.class;
|
|
||||||
case 2:
|
|
||||||
return IAMax.class;
|
|
||||||
case 3:
|
|
||||||
return IAMin.class;
|
|
||||||
case 4:
|
case 4:
|
||||||
return FirstIndex.class;
|
return FirstIndex.class;
|
||||||
case 5:
|
case 5:
|
||||||
|
|
|
@ -1055,10 +1055,6 @@ public class OpValidation {
|
||||||
IsNumericTensor.class,
|
IsNumericTensor.class,
|
||||||
//Exclude index accumulations (index out, not real-valued)
|
//Exclude index accumulations (index out, not real-valued)
|
||||||
FirstIndex.class,
|
FirstIndex.class,
|
||||||
IAMax.class,
|
|
||||||
IAMin.class,
|
|
||||||
IMax.class,
|
|
||||||
IMin.class,
|
|
||||||
LastIndex.class,
|
LastIndex.class,
|
||||||
ArgMax.class,
|
ArgMax.class,
|
||||||
ArgMin.class,
|
ArgMin.class,
|
||||||
|
|
|
@ -105,13 +105,11 @@ public class ImportClassMapping {
|
||||||
org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class,
|
org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class,
|
||||||
org.nd4j.linalg.api.ops.impl.image.ResizeArea.class,
|
org.nd4j.linalg.api.ops.impl.image.ResizeArea.class,
|
||||||
org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class,
|
org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class,
|
||||||
org.nd4j.linalg.api.ops.impl.indexaccum.IAMax.class,
|
|
||||||
org.nd4j.linalg.api.ops.impl.indexaccum.IAMin.class,
|
|
||||||
org.nd4j.linalg.api.ops.impl.indexaccum.IMax.class,
|
|
||||||
org.nd4j.linalg.api.ops.impl.indexaccum.IMin.class,
|
|
||||||
org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex.class,
|
org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex.class,
|
||||||
org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax.class,
|
org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax.class,
|
||||||
org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin.class,
|
org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax.class,
|
||||||
|
org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmin.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction.class,
|
org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D.class,
|
||||||
org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D.class,
|
org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D.class,
|
||||||
|
|
|
@ -1,78 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculate the index of the max absolute value over a vector
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public class IAMax extends BaseIndexAccumulation {
|
|
||||||
public IAMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
|
||||||
super(sameDiff, i_v, keepDims, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public IAMax() {}
|
|
||||||
|
|
||||||
public IAMax(INDArray x, int... dimensions) {
|
|
||||||
this(x, false, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public IAMax(INDArray x, boolean keepDims, int... dimensions) {
|
|
||||||
this(x, null, dimensions);
|
|
||||||
this.keepDims = keepDims;
|
|
||||||
}
|
|
||||||
|
|
||||||
public IAMax(INDArray x, INDArray z, int... dimensions) {
|
|
||||||
super(x, z, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int opNum() {
|
|
||||||
return 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String opName() {
|
|
||||||
return "iamax";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
return "AbsArgMax";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "absargmax";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
|
||||||
return Collections.singletonList(sameDiff.zerosLike(arg()));
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,80 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculate the index of the max absolute value over a vector
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public class IAMin extends BaseIndexAccumulation {
|
|
||||||
public IAMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
|
||||||
super(sameDiff, i_v, keepDims, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public IAMin() {}
|
|
||||||
|
|
||||||
public IAMin(INDArray x, int... dimensions) {
|
|
||||||
super(x, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public IAMin(INDArray in, boolean keepDims, int... dimnesions){
|
|
||||||
super(in, null, dimnesions);
|
|
||||||
this.keepDims = keepDims;
|
|
||||||
}
|
|
||||||
|
|
||||||
public IAMin(INDArray x, INDArray z, int... dimensions) {
|
|
||||||
super(x, z, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int opNum() {
|
|
||||||
return 3;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String opName() {
|
|
||||||
return "iamin";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
return "AbsArgMin";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "absargmin";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<SDVariable> doDiff(List<SDVariable> grad){
|
|
||||||
return Collections.singletonList(sameDiff.zerosLike(arg()));
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,87 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculate the index
|
|
||||||
* of max value over a vector
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public class IMax extends BaseIndexAccumulation {
|
|
||||||
public IMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
|
||||||
super(sameDiff, i_v, keepDims, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public IMax() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public IMax(INDArray x, INDArray z, int... dimensions) {
|
|
||||||
super(x, z, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public IMax(INDArray x, int... dimensions) {
|
|
||||||
super(x, null, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public IMax(INDArray x, boolean keepDims, int... dimensions) {
|
|
||||||
super(x, null, dimensions);
|
|
||||||
this.keepDims = keepDims;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int opNum() {
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String opName() {
|
|
||||||
return "imax";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
return "arg_max";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Type opType() {
|
|
||||||
return Type.INDEXREDUCE;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
|
||||||
//Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input
|
|
||||||
return Collections.singletonList(sameDiff.zerosLike(arg()));
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,83 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import org.nd4j.autodiff.samediff.SDVariable;
|
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
|
||||||
import org.nd4j.imports.NoOpNameFoundException;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.BaseIndexAccumulation;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculate the index of min value over a vector
|
|
||||||
*
|
|
||||||
* @author Alex Black
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
public class IMin extends BaseIndexAccumulation {
|
|
||||||
public IMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
|
||||||
super(sameDiff, i_v, keepDims, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public IMin() {
|
|
||||||
}
|
|
||||||
|
|
||||||
public IMin(INDArray x, int... dimensions) {
|
|
||||||
super(x, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public IMin(INDArray x, boolean keepDims, int... dimensions) {
|
|
||||||
super(x, keepDims, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
public IMin(INDArray x, INDArray z, int... dimensions) {
|
|
||||||
super(x, z, dimensions);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int opNum() {
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String opName() {
|
|
||||||
return "imin";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String onnxName() {
|
|
||||||
return "ArgMin";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
|
||||||
//Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input
|
|
||||||
return Collections.singletonList(sameDiff.zerosLike(arg()));
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -0,0 +1,111 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class ArgAmax extends DynamicCustomOp {
|
||||||
|
protected boolean keepDims = false;
|
||||||
|
private int[] dimensions;
|
||||||
|
|
||||||
|
protected DataType outputType = DataType.INT64;
|
||||||
|
|
||||||
|
public ArgAmax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
||||||
|
super(sameDiff, i_v);
|
||||||
|
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
this.dimensions = dimensions;
|
||||||
|
|
||||||
|
if (dimensions != null && dimensions.length > 0)
|
||||||
|
addIArgument(dimensions);
|
||||||
|
|
||||||
|
addBArgument(keepDims);
|
||||||
|
|
||||||
|
addDArgument(outputType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgAmax() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgAmax(INDArray x, INDArray z, boolean keepDims, int... dimensions) {
|
||||||
|
super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]);
|
||||||
|
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
this.dimensions = dimensions;
|
||||||
|
|
||||||
|
if (dimensions != null && dimensions.length > 0)
|
||||||
|
addIArgument(dimensions);
|
||||||
|
|
||||||
|
addBArgument(keepDims);
|
||||||
|
|
||||||
|
addDArgument(outputType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgAmax(INDArray x, INDArray z, int... dimensions) {
|
||||||
|
this(x, z, false, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgAmax(INDArray x, int... dimensions) {
|
||||||
|
this(x, null, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgAmax(INDArray x, boolean keepDims, int... dimensions) {
|
||||||
|
this(x, null, keepDims, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "argamax";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
|
if(attributesForNode.containsKey("output_type")) {
|
||||||
|
outputType = TFGraphMapper.convertType(attributesForNode.get("output_type").getType());
|
||||||
|
} else {
|
||||||
|
outputType = DataType.LONG;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
|
||||||
|
"Expected 1 or 2 input datatype to argamax, got %s", inputDataTypes); //2nd input: axis
|
||||||
|
return Collections.singletonList(outputType == null ? DataType.LONG : outputType);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,111 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
|
import lombok.Data;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
|
import org.nd4j.common.base.Preconditions;
|
||||||
|
import org.nd4j.imports.NoOpNameFoundException;
|
||||||
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
|
import org.tensorflow.framework.AttrValue;
|
||||||
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
import org.tensorflow.framework.NodeDef;
|
||||||
|
|
||||||
|
import java.util.Collections;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
@Data
|
||||||
|
public class ArgAmin extends DynamicCustomOp {
|
||||||
|
protected boolean keepDims = false;
|
||||||
|
private int[] dimensions;
|
||||||
|
|
||||||
|
protected DataType outputType = DataType.INT64;
|
||||||
|
|
||||||
|
public ArgAmin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
||||||
|
super(sameDiff, i_v);
|
||||||
|
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
this.dimensions = dimensions;
|
||||||
|
|
||||||
|
if (dimensions != null && dimensions.length > 0)
|
||||||
|
addIArgument(dimensions);
|
||||||
|
|
||||||
|
addBArgument(keepDims);
|
||||||
|
|
||||||
|
addDArgument(outputType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgAmin() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgAmin(INDArray x, INDArray z, boolean keepDims, int... dimensions) {
|
||||||
|
super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]);
|
||||||
|
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
this.dimensions = dimensions;
|
||||||
|
|
||||||
|
if (dimensions != null && dimensions.length > 0)
|
||||||
|
addIArgument(dimensions);
|
||||||
|
|
||||||
|
addBArgument(keepDims);
|
||||||
|
|
||||||
|
addDArgument(outputType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgAmin(INDArray x, INDArray z, int... dimensions) {
|
||||||
|
this(x, z, false, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgAmin(INDArray x, int... dimensions) {
|
||||||
|
this(x, null, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgAmin(INDArray x, boolean keepDims, int... dimensions) {
|
||||||
|
this(x, null, keepDims, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String opName() {
|
||||||
|
return "argamin";
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String tensorflowName() {
|
||||||
|
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
|
||||||
|
if(attributesForNode.containsKey("output_type")) {
|
||||||
|
outputType = TFGraphMapper.convertType(attributesForNode.get("output_type").getType());
|
||||||
|
} else {
|
||||||
|
outputType = DataType.LONG;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
|
||||||
|
Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2),
|
||||||
|
"Expected 1 or 2 input datatype to argamin, got %s", inputDataTypes); //2nd input: axis
|
||||||
|
return Collections.singletonList(outputType == null ? DataType.LONG : outputType);
|
||||||
|
}
|
||||||
|
}
|
|
@ -17,10 +17,12 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -32,8 +34,53 @@ import java.util.Map;
|
||||||
|
|
||||||
@Data
|
@Data
|
||||||
public class ArgMax extends DynamicCustomOp {
|
public class ArgMax extends DynamicCustomOp {
|
||||||
|
protected boolean keepDims = false;
|
||||||
|
private int[] dimensions;
|
||||||
|
|
||||||
protected DataType outputType;
|
protected DataType outputType = DataType.INT64;
|
||||||
|
|
||||||
|
public ArgMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
||||||
|
super(sameDiff, i_v);
|
||||||
|
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
this.dimensions = dimensions;
|
||||||
|
|
||||||
|
if (dimensions != null && dimensions.length > 0)
|
||||||
|
addIArgument(dimensions);
|
||||||
|
|
||||||
|
addBArgument(keepDims);
|
||||||
|
|
||||||
|
addDArgument(outputType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgMax() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgMax(INDArray x, INDArray z, boolean keepDims, int... dimensions) {
|
||||||
|
super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]);
|
||||||
|
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
this.dimensions = dimensions;
|
||||||
|
|
||||||
|
if (dimensions != null && dimensions.length > 0)
|
||||||
|
addIArgument(dimensions);
|
||||||
|
|
||||||
|
addBArgument(keepDims);
|
||||||
|
|
||||||
|
addDArgument(outputType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgMax(INDArray x, INDArray z, int... dimensions) {
|
||||||
|
this(x, z, false, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgMax(INDArray x, int... dimensions) {
|
||||||
|
this(x, null, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgMax(INDArray x, boolean keepDims, int... dimensions) {
|
||||||
|
this(x, null, keepDims, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
|
|
|
@ -17,10 +17,12 @@
|
||||||
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
package org.nd4j.linalg.api.ops.impl.indexaccum.custom;
|
||||||
|
|
||||||
import lombok.Data;
|
import lombok.Data;
|
||||||
|
import org.nd4j.autodiff.samediff.SDVariable;
|
||||||
import org.nd4j.autodiff.samediff.SameDiff;
|
import org.nd4j.autodiff.samediff.SameDiff;
|
||||||
import org.nd4j.common.base.Preconditions;
|
import org.nd4j.common.base.Preconditions;
|
||||||
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
|
||||||
import org.nd4j.linalg.api.buffer.DataType;
|
import org.nd4j.linalg.api.buffer.DataType;
|
||||||
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.tensorflow.framework.AttrValue;
|
import org.tensorflow.framework.AttrValue;
|
||||||
import org.tensorflow.framework.GraphDef;
|
import org.tensorflow.framework.GraphDef;
|
||||||
|
@ -37,8 +39,53 @@ import java.util.Map;
|
||||||
*/
|
*/
|
||||||
@Data
|
@Data
|
||||||
public class ArgMin extends DynamicCustomOp {
|
public class ArgMin extends DynamicCustomOp {
|
||||||
|
protected boolean keepDims = false;
|
||||||
|
private int[] dimensions;
|
||||||
|
|
||||||
protected DataType outputType = DataType.LONG;
|
protected DataType outputType = DataType.INT64;
|
||||||
|
|
||||||
|
public ArgMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) {
|
||||||
|
super(sameDiff, i_v);
|
||||||
|
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
this.dimensions = dimensions;
|
||||||
|
|
||||||
|
if (dimensions != null && dimensions.length > 0)
|
||||||
|
addIArgument(dimensions);
|
||||||
|
|
||||||
|
addBArgument(keepDims);
|
||||||
|
|
||||||
|
addDArgument(outputType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgMin() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgMin(INDArray x, INDArray z, boolean keepDims, int... dimensions) {
|
||||||
|
super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]);
|
||||||
|
|
||||||
|
this.keepDims = keepDims;
|
||||||
|
this.dimensions = dimensions;
|
||||||
|
|
||||||
|
if (dimensions != null && dimensions.length > 0)
|
||||||
|
addIArgument(dimensions);
|
||||||
|
|
||||||
|
addBArgument(keepDims);
|
||||||
|
|
||||||
|
addDArgument(outputType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgMin(INDArray x, INDArray z, int... dimensions) {
|
||||||
|
this(x, z, false, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgMin(INDArray x, int... dimensions) {
|
||||||
|
this(x, null, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ArgMin(INDArray x, boolean keepDims, int... dimensions) {
|
||||||
|
this(x, null, keepDims, dimensions);
|
||||||
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public String opName() {
|
public String opName() {
|
||||||
|
|
|
@ -17,6 +17,8 @@
|
||||||
package org.nd4j.linalg.factory;
|
package org.nd4j.linalg.factory;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||||
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
||||||
import org.nd4j.linalg.factory.ops.*;
|
import org.nd4j.linalg.factory.ops.*;
|
||||||
import org.nd4j.shade.guava.primitives.Ints;
|
import org.nd4j.shade.guava.primitives.Ints;
|
||||||
import org.nd4j.shade.guava.primitives.Longs;
|
import org.nd4j.shade.guava.primitives.Longs;
|
||||||
|
@ -50,8 +52,6 @@ import org.nd4j.linalg.api.ops.Op;
|
||||||
import org.nd4j.linalg.api.ops.OpContext;
|
import org.nd4j.linalg.api.ops.OpContext;
|
||||||
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
import org.nd4j.linalg.api.ops.impl.reduce.Mmul;
|
||||||
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
|
import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans;
|
||||||
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
|
import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate;
|
||||||
|
@ -627,16 +627,16 @@ public class Nd4j {
|
||||||
* @return array of maximum values.
|
* @return array of maximum values.
|
||||||
*/
|
*/
|
||||||
public static INDArray argMax(INDArray arr, @NonNull int... dimension) {
|
public static INDArray argMax(INDArray arr, @NonNull int... dimension) {
|
||||||
IMax imax = new IMax(arr, dimension);
|
val imax = new ArgMax(arr, dimension);
|
||||||
return Nd4j.getExecutioner().exec(imax);
|
return Nd4j.getExecutioner().exec(imax)[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* See {@link #argMax(INDArray, int...)} but return minimum values.
|
* See {@link #argMax(INDArray, int...)} but return minimum values.
|
||||||
*/
|
*/
|
||||||
public static INDArray argMin(INDArray arr, @NonNull int... dimension) {
|
public static INDArray argMin(INDArray arr, @NonNull int... dimension) {
|
||||||
IMin imin = new IMin(arr, dimension);
|
val imin = new ArgMin(arr, dimension);
|
||||||
return Nd4j.getExecutioner().exec(imin);
|
return Nd4j.getExecutioner().exec(imin)[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -75,7 +75,7 @@ public class NDBase {
|
||||||
public INDArray argmax(INDArray in, boolean keepDims, int... dimensions) {
|
public INDArray argmax(INDArray in, boolean keepDims, int... dimensions) {
|
||||||
NDValidation.validateNumerical("argmax", "in", in);
|
NDValidation.validateNumerical("argmax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(in, keepDims, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -97,7 +97,7 @@ public class NDBase {
|
||||||
public INDArray argmax(INDArray in, int... dimensions) {
|
public INDArray argmax(INDArray in, int... dimensions) {
|
||||||
NDValidation.validateNumerical("argmax", "in", in);
|
NDValidation.validateNumerical("argmax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(in, false, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -123,7 +123,7 @@ public class NDBase {
|
||||||
public INDArray argmin(INDArray in, boolean keepDims, int... dimensions) {
|
public INDArray argmin(INDArray in, boolean keepDims, int... dimensions) {
|
||||||
NDValidation.validateNumerical("argmin", "in", in);
|
NDValidation.validateNumerical("argmin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, keepDims, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -148,7 +148,7 @@ public class NDBase {
|
||||||
public INDArray argmin(INDArray in, int... dimensions) {
|
public INDArray argmin(INDArray in, int... dimensions) {
|
||||||
NDValidation.validateNumerical("argmin", "in", in);
|
NDValidation.validateNumerical("argmin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -896,7 +896,7 @@ public class NDMath {
|
||||||
public INDArray iamax(INDArray in, int... dimensions) {
|
public INDArray iamax(INDArray in, int... dimensions) {
|
||||||
NDValidation.validateNumerical("iamax", "in", in);
|
NDValidation.validateNumerical("iamax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, false, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -911,7 +911,7 @@ public class NDMath {
|
||||||
public INDArray iamax(INDArray in, boolean keepDims, int... dimensions) {
|
public INDArray iamax(INDArray in, boolean keepDims, int... dimensions) {
|
||||||
NDValidation.validateNumerical("iamax", "in", in);
|
NDValidation.validateNumerical("iamax", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, keepDims, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -925,7 +925,7 @@ public class NDMath {
|
||||||
public INDArray iamin(INDArray in, int... dimensions) {
|
public INDArray iamin(INDArray in, int... dimensions) {
|
||||||
NDValidation.validateNumerical("iamin", "in", in);
|
NDValidation.validateNumerical("iamin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, false, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -940,7 +940,7 @@ public class NDMath {
|
||||||
public INDArray iamin(INDArray in, boolean keepDims, int... dimensions) {
|
public INDArray iamin(INDArray in, boolean keepDims, int... dimensions) {
|
||||||
NDValidation.validateNumerical("iamin", "in", in);
|
NDValidation.validateNumerical("iamin", "in", in);
|
||||||
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length);
|
||||||
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, keepDims, dimensions));
|
return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0];
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -17469,6 +17469,60 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
||||||
}
|
}
|
||||||
// #endif
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation returns index of absolute max element in a given NDArray (optionally: along given dimension(s))
|
||||||
|
* Expected input:
|
||||||
|
* 0: N-dimensional array
|
||||||
|
* 1: optional axis vector
|
||||||
|
*
|
||||||
|
* Int args:
|
||||||
|
* 0: optional axis
|
||||||
|
*/
|
||||||
|
// #if NOT_EXCLUDED(OP_argamax)
|
||||||
|
@Namespace("sd::ops") public static class argamax extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public argamax(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public argamax(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public argamax position(long position) {
|
||||||
|
return (argamax)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public argamax() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation returns index of absolute min element in a given NDArray (optionally: along given dimension(s))
|
||||||
|
* Expected input:
|
||||||
|
* 0: N-dimensional array
|
||||||
|
* 1: optional axis vector
|
||||||
|
*
|
||||||
|
* Int args:
|
||||||
|
* 0: optional axis
|
||||||
|
*/
|
||||||
|
// #if NOT_EXCLUDED(OP_argamin)
|
||||||
|
@Namespace("sd::ops") public static class argamin extends DeclarableCustomOp {
|
||||||
|
static { Loader.load(); }
|
||||||
|
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||||
|
public argamin(Pointer p) { super(p); }
|
||||||
|
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||||
|
public argamin(long size) { super((Pointer)null); allocateArray(size); }
|
||||||
|
private native void allocateArray(long size);
|
||||||
|
@Override public argamin position(long position) {
|
||||||
|
return (argamin)super.position(position);
|
||||||
|
}
|
||||||
|
|
||||||
|
public argamin() { super((Pointer)null); allocate(); }
|
||||||
|
private native void allocate();
|
||||||
|
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||||
|
}
|
||||||
|
// #endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This operation provides various normalization modes:
|
* This operation provides various normalization modes:
|
||||||
* 0: frobenius
|
* 0: frobenius
|
||||||
|
|
|
@ -32,8 +32,8 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmin;
|
||||||
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
|
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
|
import org.nd4j.linalg.api.ops.impl.reduce.Moments;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
|
import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments;
|
||||||
|
@ -863,12 +863,12 @@ public class ReductionOpValidation extends BaseOpValidation {
|
||||||
break;
|
break;
|
||||||
case 2:
|
case 2:
|
||||||
reduce = sd.math().iamax(s, dim);
|
reduce = sd.math().iamax(s, dim);
|
||||||
exp = Nd4j.getExecutioner().exec(new IAMax(in.dup(), dim));
|
exp = Nd4j.getExecutioner().exec(new ArgAmax(in.dup(), dim))[0];
|
||||||
name = "iamax";
|
name = "iamax";
|
||||||
break;
|
break;
|
||||||
case 3:
|
case 3:
|
||||||
reduce = sd.math().iamin(s, dim);
|
reduce = sd.math().iamin(s, dim);
|
||||||
exp = Nd4j.getExecutioner().exec(new IAMin(in.dup(), dim));
|
exp = Nd4j.getExecutioner().exec(new ArgAmin(in.dup(), dim))[0];
|
||||||
name = "iamin";
|
name = "iamin";
|
||||||
break;
|
break;
|
||||||
case 4:
|
case 4:
|
||||||
|
|
|
@ -144,7 +144,7 @@ public class NameScopeTests extends BaseNd4jTest {
|
||||||
|
|
||||||
scope.close();
|
scope.close();
|
||||||
|
|
||||||
assertTrue("Var with name test/imax exists", SD.variableMap().containsKey("test/imax"));
|
assertTrue("Var with name test/argmax exists", SD.variableMap().containsKey("test/argmax"));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
|
|
|
@ -52,10 +52,10 @@ import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan;
|
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmin;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col;
|
||||||
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
|
||||||
|
@ -3765,10 +3765,10 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL);
|
Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL);
|
||||||
|
|
||||||
INDArray arr = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01});
|
INDArray arr = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01});
|
||||||
IMax iMax = new IMax(arr);
|
val iMax = new ArgMax(arr);
|
||||||
IAMax iaMax = new IAMax(arr.dup());
|
val iaMax = new ArgAmax(arr.dup());
|
||||||
val imax = Nd4j.getExecutioner().execAndReturn(iMax).getFinalResult().intValue();
|
val imax = Nd4j.getExecutioner().exec(iMax)[0].getInt(0);
|
||||||
val iamax = Nd4j.getExecutioner().execAndReturn(iaMax).getFinalResult().intValue();
|
val iamax = Nd4j.getExecutioner().exec(iaMax)[0].getInt(0);
|
||||||
// System.out.println("IMAX: " + imax);
|
// System.out.println("IMAX: " + imax);
|
||||||
// System.out.println("IAMAX: " + iamax);
|
// System.out.println("IAMAX: " + iamax);
|
||||||
assertEquals(1, iamax);
|
assertEquals(1, iamax);
|
||||||
|
@ -3780,10 +3780,10 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
public void testIMinIAMin() {
|
public void testIMinIAMin() {
|
||||||
INDArray arr = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01});
|
INDArray arr = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01});
|
||||||
INDArray abs = Transforms.abs(arr);
|
INDArray abs = Transforms.abs(arr);
|
||||||
IAMin iaMin = new IAMin(abs);
|
val iaMin = new ArgAmin(abs);
|
||||||
IMin iMin = new IMin(arr.dup());
|
val iMin = new ArgMin(arr.dup());
|
||||||
double imin = Nd4j.getExecutioner().execAndReturn(iMin).getFinalResult().doubleValue();
|
double imin = Nd4j.getExecutioner().exec(iMin)[0].getDouble(0);
|
||||||
double iamin = Nd4j.getExecutioner().execAndReturn(iaMin).getFinalResult().doubleValue();
|
double iamin = Nd4j.getExecutioner().exec(iaMin)[0].getDouble(0);
|
||||||
// System.out.println("IMin: " + imin);
|
// System.out.println("IMin: " + imin);
|
||||||
// System.out.println("IAMin: " + iamin);
|
// System.out.println("IAMin: " + iamin);
|
||||||
assertEquals(3, iamin, 1e-12);
|
assertEquals(3, iamin, 1e-12);
|
||||||
|
@ -4077,7 +4077,7 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
arr.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all()).assign(Nd4j.create(slices[i]));
|
arr.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all()).assign(Nd4j.create(slices[i]));
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray out = Nd4j.getExecutioner().exec(new IMax(arr, 1,2));
|
INDArray out = Nd4j.exec(new ArgMax(arr, 1,2))[0];
|
||||||
|
|
||||||
assertEquals(DataType.LONG, out.dataType());
|
assertEquals(DataType.LONG, out.dataType());
|
||||||
|
|
||||||
|
@ -4119,8 +4119,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
INDArray actC = Nd4j.getExecutioner().exec(new IMax(arr.dup('c'), 0,1));
|
INDArray actC = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('c'), 0,1))[0];
|
||||||
INDArray actF = Nd4j.getExecutioner().exec(new IMax(arr.dup('f'), 0,1));
|
INDArray actF = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('f'), 0,1))[0];
|
||||||
//
|
//
|
||||||
assertEquals(exp, actC);
|
assertEquals(exp, actC);
|
||||||
assertEquals(exp, actF);
|
assertEquals(exp, actF);
|
||||||
|
@ -4153,8 +4153,8 @@ public class Nd4jTestsC extends BaseNd4jTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
actC = Nd4j.getExecutioner().exec(new IMax(arr.dup('c'), 2, 3));
|
actC = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('c'), 2, 3))[0];
|
||||||
actF = Nd4j.getExecutioner().exec(new IMax(arr.dup('f'), 2, 3));
|
actF = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('f'), 2, 3))[0];
|
||||||
|
|
||||||
assertEquals(exp, actC);
|
assertEquals(exp, actC);
|
||||||
assertEquals(exp, actF);
|
assertEquals(exp, actF);
|
||||||
|
|
|
@ -25,7 +25,7 @@ import org.junit.runners.Parameterized;
|
||||||
import org.nd4j.linalg.BaseNd4jTest;
|
import org.nd4j.linalg.BaseNd4jTest;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
|
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
|
||||||
|
@ -122,7 +122,7 @@ public class CrashTest extends BaseNd4jTest {
|
||||||
float sum = x.sumNumber().floatValue();
|
float sum = x.sumNumber().floatValue();
|
||||||
|
|
||||||
// index reduction
|
// index reduction
|
||||||
Nd4j.getExecutioner().exec(new IMax(x));
|
Nd4j.getExecutioner().exec(new ArgMax(x));
|
||||||
|
|
||||||
// casual transform
|
// casual transform
|
||||||
Nd4j.getExecutioner().exec(new Sqrt(x, x));
|
Nd4j.getExecutioner().exec(new Sqrt(x, x));
|
||||||
|
|
|
@ -26,9 +26,9 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
|
||||||
|
@ -282,9 +282,9 @@ public class OpExecutionerTests extends BaseNd4jTest {
|
||||||
public void testIamax2() {
|
public void testIamax2() {
|
||||||
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE);
|
||||||
assertEquals(getFailureMessage(), 3, Nd4j.getBlasWrapper().iamax(linspace));
|
assertEquals(getFailureMessage(), 3, Nd4j.getBlasWrapper().iamax(linspace));
|
||||||
val op = new IAMax(linspace);
|
val op = new ArgAmax(linspace);
|
||||||
|
|
||||||
int iamax = Nd4j.getExecutioner().execAndReturn(op).getFinalResult().intValue();
|
int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0);
|
||||||
assertEquals(3, iamax);
|
assertEquals(3, iamax);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -565,24 +565,24 @@ public class OpExecutionerTests extends BaseNd4jTest {
|
||||||
@Test
|
@Test
|
||||||
public void testIMax() {
|
public void testIMax() {
|
||||||
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
|
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
|
||||||
IMax imax = new IMax(arr);
|
ArgMax imax = new ArgMax(arr);
|
||||||
assertEquals(9, Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue());
|
assertEquals(9, Nd4j.getExecutioner().exec(imax)[0].getInt(0));
|
||||||
|
|
||||||
arr.muli(-1);
|
arr.muli(-1);
|
||||||
imax = new IMax(arr);
|
imax = new ArgMax(arr);
|
||||||
int maxIdx = Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue();
|
int maxIdx = Nd4j.getExecutioner().exec(imax)[0].getInt(0);
|
||||||
assertEquals(0, maxIdx);
|
assertEquals(0, maxIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIMin() {
|
public void testIMin() {
|
||||||
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
|
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
|
||||||
IMin imin = new IMin(arr);
|
ArgMin imin = new ArgMin(arr);
|
||||||
assertEquals(0, Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue());
|
assertEquals(0, Nd4j.getExecutioner().exec(imin)[0].getInt(0));
|
||||||
|
|
||||||
arr.muli(-1);
|
arr.muli(-1);
|
||||||
imin = new IMin(arr);
|
imin = new ArgMin(arr);
|
||||||
int minIdx = Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue();
|
int minIdx = Nd4j.getExecutioner().exec(imin)[0].getInt(0);
|
||||||
assertEquals(9, minIdx);
|
assertEquals(9, minIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,8 +32,8 @@ import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.CustomOp;
|
import org.nd4j.linalg.api.ops.CustomOp;
|
||||||
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
|
||||||
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
|
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMax;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax;
|
||||||
import org.nd4j.linalg.api.ops.impl.indexaccum.IMin;
|
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
|
import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax;
|
||||||
|
@ -478,24 +478,24 @@ public class OpExecutionerTestsC extends BaseNd4jTest {
|
||||||
@Test
|
@Test
|
||||||
public void testIMax() {
|
public void testIMax() {
|
||||||
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
|
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
|
||||||
IMax imax = new IMax(arr);
|
ArgMax imax = new ArgMax(arr);
|
||||||
assertEquals(9, Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue());
|
assertEquals(9, Nd4j.getExecutioner().exec(imax)[0].getInt(0));
|
||||||
|
|
||||||
arr.muli(-1);
|
arr.muli(-1);
|
||||||
imax = new IMax(arr);
|
imax = new ArgMax(arr);
|
||||||
int maxIdx = Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue();
|
int maxIdx = Nd4j.getExecutioner().exec(imax)[0].getInt(0);
|
||||||
assertEquals(0, maxIdx);
|
assertEquals(0, maxIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testIMin() {
|
public void testIMin() {
|
||||||
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
|
INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE);
|
||||||
IMin imin = new IMin(arr);
|
ArgMin imin = new ArgMin(arr);
|
||||||
assertEquals(0, Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue());
|
assertEquals(0, Nd4j.getExecutioner().exec(imin)[0].getInt(0));
|
||||||
|
|
||||||
arr.muli(-1);
|
arr.muli(-1);
|
||||||
imin = new IMin(arr);
|
imin = new ArgMin(arr);
|
||||||
int minIdx = Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue();
|
int minIdx = Nd4j.getExecutioner().exec(imin)[0].getInt(0);
|
||||||
assertEquals(9, minIdx);
|
assertEquals(9, minIdx);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -26,6 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
|
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
|
||||||
|
import org.nd4j.linalg.exception.ND4JIllegalStateException;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
import org.nd4j.linalg.factory.Nd4jBackend;
|
import org.nd4j.linalg.factory.Nd4jBackend;
|
||||||
|
|
||||||
|
@ -234,7 +235,7 @@ public class EmptyTests extends BaseNd4jTest {
|
||||||
assertEquals(e, reduced);
|
assertEquals(e, reduced);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test(expected = IllegalArgumentException.class)
|
@Test(expected = ND4JIllegalStateException.class)
|
||||||
public void testEmptyReduction_4() {
|
public void testEmptyReduction_4() {
|
||||||
val x = Nd4j.create(DataType.FLOAT, 2, 0);
|
val x = Nd4j.create(DataType.FLOAT, 2, 0);
|
||||||
val e = Nd4j.create(DataType.FLOAT, 0);
|
val e = Nd4j.create(DataType.FLOAT, 0);
|
||||||
|
|
Loading…
Reference in New Issue