[WIP] INDArray hashCode() impl (#50)
* initial commit Signed-off-by: raver119@gmail.com <raver119@gmail.com> * one more initial commit Signed-off-by: raver119 <raver119@gmail.com> * parallel hashCode prototype Signed-off-by: raver119 <raver119@gmail.com> * longBytes for hashCode Signed-off-by: raver119 <raver119@gmail.com> * INDArray hashCode java side Signed-off-by: raver119 <raver119@gmail.com> * few tests fixed for MSVC Signed-off-by: raver119 <raver119@gmail.com> * Small gradcheck validation util fix - hash names not SDVariables Signed-off-by: Alex Black <blacka101@gmail.com> * Small fix + ignore for logged issue Signed-off-by: Alex Black <blacka101@gmail.com> * - scrollable iterator fix - sptree hashset replaced with collection Signed-off-by: raver119 <raver119@gmail.com> * hashcode exception removed * int hashCode for javamaster
parent
cc65c01118
commit
5708fc087a
|
@ -114,6 +114,11 @@ public class ScrollableMultiDataSetIterator implements MultiDataSetIterator {
|
|||
return p;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void remove() {
|
||||
//
|
||||
}
|
||||
|
||||
@Override
|
||||
public MultiDataSet next(int i) {
|
||||
throw new UnsupportedOperationException();
|
||||
|
|
|
@ -29,7 +29,8 @@ import org.slf4j.Logger;
|
|||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.HashSet;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Set;
|
||||
|
||||
|
||||
|
@ -55,20 +56,20 @@ public class SpTree implements Serializable {
|
|||
private int nodeCapacity;
|
||||
private int numChildren = 2;
|
||||
private boolean isLeaf = true;
|
||||
private Set<INDArray> indices;
|
||||
private Collection<INDArray> indices;
|
||||
private SpTree[] children;
|
||||
private static Logger log = LoggerFactory.getLogger(SpTree.class);
|
||||
private String similarityFunction = Distance.EUCLIDEAN.toString();
|
||||
|
||||
|
||||
|
||||
public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Set<INDArray> indices,
|
||||
public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
|
||||
String similarityFunction) {
|
||||
init(parent, data, corner, width, indices, similarityFunction);
|
||||
}
|
||||
|
||||
|
||||
public SpTree(INDArray data, Set<INDArray> indices, String similarityFunction) {
|
||||
public SpTree(INDArray data, Collection<INDArray> indices, String similarityFunction) {
|
||||
this.indices = indices;
|
||||
this.N = data.rows();
|
||||
this.D = data.columns();
|
||||
|
@ -90,26 +91,26 @@ public class SpTree implements Serializable {
|
|||
}
|
||||
|
||||
|
||||
public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Set<INDArray> indices) {
|
||||
public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices) {
|
||||
this(parent, data, corner, width, indices, "euclidean");
|
||||
}
|
||||
|
||||
|
||||
public SpTree(INDArray data, Set<INDArray> indices) {
|
||||
public SpTree(INDArray data, Collection<INDArray> indices) {
|
||||
this(data, indices, "euclidean");
|
||||
}
|
||||
|
||||
|
||||
|
||||
public SpTree(INDArray data) {
|
||||
this(data, new HashSet<INDArray>());
|
||||
this(data, new ArrayList<INDArray>());
|
||||
}
|
||||
|
||||
public MemoryWorkspace workspace() {
|
||||
return null;
|
||||
}
|
||||
|
||||
private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Set<INDArray> indices,
|
||||
private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
|
||||
String similarityFunction) {
|
||||
|
||||
this.parent = parent;
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_hashcode)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/transforms.h>
|
||||
#include <ops/declarable/helpers/hashcode.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
REDUCTION_OP_IMPL(hashcode, 1, 1, false, 0, 0) {
|
||||
REQUIRE_TRUE(block.width() == 1, 0, "hashcode: this op can't be applied along dimension");
|
||||
|
||||
auto input = INPUT_VARIABLE(0);
|
||||
auto output = OUTPUT_VARIABLE(0);
|
||||
|
||||
REQUIRE_TRUE(output->isScalar(), 0, "hashcode: this op requires scalar output");
|
||||
|
||||
helpers::hashCode(block.launchContext(), *input, *output);
|
||||
|
||||
return Status::OK();
|
||||
};
|
||||
|
||||
|
||||
DECLARE_TYPES(hashcode) {
|
||||
getOpDescriptor()
|
||||
->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
|
||||
->setAllowedInputTypes(1, {ALL_INTS})
|
||||
->setAllowedOutputTypes({nd4j::DataType::INT64});
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
@ -199,6 +199,13 @@ namespace nd4j {
|
|||
DECLARE_CONFIGURABLE_OP(standardize, 1, 1, true, 0, -2);
|
||||
DECLARE_CUSTOM_OP(standardize_bp, 2, 1, false, 0, -2);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation calculates hash code, optionally along dimension
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_hashcode)
|
||||
DECLARE_REDUCTION_OP(hashcode, 1, 1, false, 0, 0);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,101 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/hashcode.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
template <typename T>
|
||||
static void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) {
|
||||
auto blockSize = 32;
|
||||
auto length = array.lengthOf();
|
||||
int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1);
|
||||
auto tempA = NDArrayFactory::create<Nd4jLong>('c', {numBlocks}, context);
|
||||
auto tempB = NDArrayFactory::create<Nd4jLong>('c', { numBlocks / blockSize + 1}, context);
|
||||
|
||||
auto buffer = array.bufferAsT<T>();
|
||||
auto tempBufferA = tempA.bufferAsT<Nd4jLong>();
|
||||
auto tempBufferB = tempB.bufferAsT<Nd4jLong>();
|
||||
|
||||
// default buffer is the first one, because it might be the last one in case of small arrays (< blockSize)
|
||||
auto tempBuffer = tempBufferA;
|
||||
auto tempResult = tempBufferB;
|
||||
|
||||
// we divide array into 32 element chunks, and store intermediate results once
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int b = 0; b < numBlocks; b++) {
|
||||
auto blockBuffer = buffer + b * numBlocks;
|
||||
|
||||
Nd4jLong r = 1;
|
||||
for (int e = 0; e < blockSize && e + (b * numBlocks) < length; e++) {
|
||||
auto v = longBytes<T>(blockBuffer[e]);
|
||||
r = 31 * r + v;
|
||||
}
|
||||
|
||||
tempBuffer[b] = r;
|
||||
}
|
||||
|
||||
// we replace pointer with intermediate one, and repeat only one chunk left
|
||||
int iterationCount = 0;
|
||||
while (numBlocks > 1) {
|
||||
int lastLength = numBlocks;
|
||||
numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1);
|
||||
|
||||
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD
|
||||
for (int b = 0; b < numBlocks; b++) {
|
||||
auto blockBuffer = tempBuffer + b * numBlocks;
|
||||
|
||||
Nd4jLong r = 1;
|
||||
for (int e = 0; e < blockSize && e + (b * numBlocks) < lastLength; e++) {
|
||||
auto v = longBytes<T>(blockBuffer[e]);
|
||||
r = 31 * r + v;
|
||||
}
|
||||
|
||||
tempResult[b] = r;
|
||||
}
|
||||
|
||||
|
||||
iterationCount++;
|
||||
// swapping buffers
|
||||
if (iterationCount % 2 == 0) {
|
||||
tempBuffer = tempBufferA;
|
||||
tempResult = tempBufferB;
|
||||
} else {
|
||||
tempBuffer = tempBufferB;
|
||||
tempResult = tempBufferA;
|
||||
}
|
||||
}
|
||||
|
||||
if (length <= blockSize)
|
||||
result.p(0, tempBufferA[0]);
|
||||
else
|
||||
result.p(0, tempResult[0]);
|
||||
}
|
||||
|
||||
|
||||
void hashCode(LaunchContext *context, NDArray &array, NDArray &result) {
|
||||
BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), LIBND4J_TYPES);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,32 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/hashcode.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
void hashCode(LaunchContext *context, NDArray &array, NDArray &result) {
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef DEV_TESTS_HASHCODE_H
|
||||
#define DEV_TESTS_HASHCODE_H
|
||||
|
||||
#include "helpers.h"
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
template <typename T>
|
||||
FORCEINLINE Nd4jLong longBytes(T value);
|
||||
|
||||
template <>
|
||||
FORCEINLINE Nd4jLong longBytes(float value) {
|
||||
int intie = *(int *)&value;
|
||||
return static_cast<Nd4jLong>(intie);
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE Nd4jLong longBytes(double value) {
|
||||
Nd4jLong longie = *(Nd4jLong *)&value;
|
||||
return longie;
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE Nd4jLong longBytes(float16 value) {
|
||||
return longBytes<float>((float) value);
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE Nd4jLong longBytes(Nd4jLong value) {
|
||||
return value;
|
||||
}
|
||||
|
||||
template <>
|
||||
FORCEINLINE Nd4jLong longBytes(bfloat16 value) {
|
||||
return longBytes<float>((float) value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
FORCEINLINE Nd4jLong longBytes(T value) {
|
||||
return longBytes<Nd4jLong>((Nd4jLong) value);
|
||||
}
|
||||
|
||||
|
||||
void hashCode(LaunchContext *context, NDArray &array, NDArray &result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif //DEV_TESTS_HASHCODE_H
|
|
@ -1925,10 +1925,10 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
|
|||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
|
||||
|
||||
int axis = 0;
|
||||
NDArray images = NDArrayFactory::create<double>('c', {1,2,2,1}, {1,2,3,4});
|
||||
NDArray boxes = NDArrayFactory::create<float>('c', {1,4}, {0,0,1,1});
|
||||
NDArray boxI = NDArrayFactory::create<int>('c', {1}, {(int)0});
|
||||
NDArray boxI = NDArrayFactory::create<int>('c', {1}, {axis});
|
||||
NDArray cropSize = NDArrayFactory::create<int>({1, 1});
|
||||
|
||||
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
||||
|
@ -1949,10 +1949,10 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
|
|||
|
||||
////////////////////////////////////////////////////////////////////
|
||||
TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) {
|
||||
|
||||
int axis = 0;
|
||||
NDArray images = NDArrayFactory::create<float>('c', {1,2,2,1}, {1,2,3,4});
|
||||
NDArray boxes = NDArrayFactory::create<float>('c', {1,4}, {0,0,1,1});
|
||||
NDArray boxI = NDArrayFactory::create<int>('c', {1}, {(int)0});
|
||||
NDArray boxI = NDArrayFactory::create<int>('c', {1}, {axis});
|
||||
NDArray cropSize = NDArrayFactory::create<int>({1, 1});
|
||||
|
||||
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
|
||||
|
|
|
@ -250,6 +250,46 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
|
|||
delete result;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, test_hashCode_1) {
|
||||
auto x = NDArrayFactory::create<int>('c', {10});
|
||||
auto y = NDArrayFactory::create<int>('c', {10});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(2.);
|
||||
|
||||
nd4j::ops::hashcode op;
|
||||
auto resultA0 = op.execute({&x}, {}, {});
|
||||
auto resultA1 = op.execute({&x}, {}, {});
|
||||
auto resultB0 = op.execute({&y}, {}, {});
|
||||
|
||||
ASSERT_EQ(*resultA0->at(0), *resultA1->at(0));
|
||||
ASSERT_NE(*resultA0->at(0), *resultB0->at(0));
|
||||
|
||||
delete resultA0;
|
||||
delete resultA1;
|
||||
delete resultB0;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, test_hashCode_2) {
|
||||
auto x = NDArrayFactory::create<int>('c', {1027});
|
||||
auto y = NDArrayFactory::create<int>('c', {1027});
|
||||
|
||||
x.linspace(1.);
|
||||
y.linspace(2.);
|
||||
|
||||
nd4j::ops::hashcode op;
|
||||
auto resultA0 = op.execute({&x}, {}, {});
|
||||
auto resultA1 = op.execute({&x}, {}, {});
|
||||
auto resultB0 = op.execute({&y}, {}, {});
|
||||
|
||||
ASSERT_EQ(*resultA0->at(0), *resultA1->at(0));
|
||||
ASSERT_NE(*resultA0->at(0), *resultB0->at(0));
|
||||
|
||||
delete resultA0;
|
||||
delete resultA1;
|
||||
delete resultB0;
|
||||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
|
||||
auto x0 = NDArrayFactory::create<Nd4jLong>(5);
|
||||
auto x1 = NDArrayFactory::create<float>('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f});
|
||||
|
|
|
@ -842,9 +842,10 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
|
|||
}
|
||||
|
||||
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
|
||||
int axis = 0;
|
||||
auto x = NDArrayFactory::create<double>('c', {1}, {10});
|
||||
auto begin = NDArrayFactory::create<int>('c', {1}, {(int)0});
|
||||
auto end = NDArrayFactory::create<int>('c', {1}, {(int)0});
|
||||
auto begin = NDArrayFactory::create<int>('c', {1}, {axis});
|
||||
auto end = NDArrayFactory::create<int>('c', {1}, {axis});
|
||||
auto stride = NDArrayFactory::create<int>('c', {1}, {1});
|
||||
//x.linspace(1);
|
||||
//auto exp = NDArrayFactory::create<double>('c', {1,3,4,5});
|
||||
|
|
|
@ -643,8 +643,8 @@ TEST_F(ParityOpsTests, Test_Select_2) {
|
|||
}
|
||||
|
||||
TEST_F(ParityOpsTests, Test_Select_3) {
|
||||
|
||||
auto mask = NDArrayFactory::create<bool>('c', {1, 1}, {false});
|
||||
bool value = false;
|
||||
auto mask = NDArrayFactory::create<bool>('c', {1, 1}, {value});
|
||||
auto x = NDArrayFactory::create<float>('c', {1, 1}, {1});
|
||||
auto y = NDArrayFactory::create<float>('c', {1, 1}, {2});
|
||||
auto exp = NDArrayFactory::create<float>('c', {1, 1}, {2});
|
||||
|
|
|
@ -1601,6 +1601,30 @@ TEST_F(PlaygroundTests, test_assign_float) {
|
|||
|
||||
}
|
||||
|
||||
TEST_F(PlaygroundTests, test_hash_1) {
|
||||
std::vector<int> vec;
|
||||
for (int e = 1; e < 100000; e++)
|
||||
vec.emplace_back(e);
|
||||
|
||||
int size = vec.size();
|
||||
int r = 0;
|
||||
PRAGMA_OMP_PARALLEL_FOR_REDUCTION(+:r)
|
||||
for (int e = 0; e < size; e++) {
|
||||
r += 31 * vec[e];
|
||||
}
|
||||
|
||||
nd4j_printf("Result: %i\n", r);
|
||||
}
|
||||
|
||||
TEST_F(PlaygroundTests, test_hash_2) {
|
||||
auto x = NDArrayFactory::create<int>('c', {5, 10000});
|
||||
x.linspace(1.f);
|
||||
|
||||
//auto h = x.reduceNumber(reduce::LongOps::HashCode);
|
||||
|
||||
//h.printIndexedBuffer("hash");
|
||||
}
|
||||
|
||||
/*
|
||||
TEST_F(PlaygroundTests, test_manual_loop) {
|
||||
const unsigned int len = 32 * 128 * 256 * 256;
|
||||
|
|
|
@ -596,8 +596,6 @@ public class GradCheckUtil {
|
|||
DifferentialFunction[] dfs = sd.functions();
|
||||
List<SDVariable> vars = sd.variables();
|
||||
|
||||
Set<SDVariable> varsSet = new HashSet<>(vars);
|
||||
Preconditions.checkState(vars.size() == varsSet.size(), "Duplicate variables in variables() list");
|
||||
Set<String> varSetStr = new HashSet<>();
|
||||
for(SDVariable v : vars){
|
||||
if(varSetStr.contains(v.getVarName())){
|
||||
|
@ -605,6 +603,7 @@ public class GradCheckUtil {
|
|||
}
|
||||
varSetStr.add(v.getVarName());
|
||||
}
|
||||
Preconditions.checkState(vars.size() == varSetStr.size(), "Duplicate variables in variables() list");
|
||||
|
||||
//1. Check incomingArgsReverse and outgoingArgsReverse
|
||||
Map<String,SameDiffOp> ops = sd.getOps();
|
||||
|
|
|
@ -173,6 +173,7 @@ public class ImportClassMapping {
|
|||
org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp.class,
|
||||
org.nd4j.linalg.api.ops.impl.nlp.CbowRound.class,
|
||||
org.nd4j.linalg.api.ops.impl.nlp.SkipGramRound.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.HashCode.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.Mmul.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.MmulBp.class,
|
||||
org.nd4j.linalg.api.ops.impl.reduce.Moments.class,
|
||||
|
|
|
@ -38,6 +38,7 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
|
|||
import org.nd4j.linalg.api.memory.MemoryWorkspace;
|
||||
import org.nd4j.linalg.api.ops.CustomOp;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
|
||||
import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
|
||||
|
@ -5444,6 +5445,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
return equalsWithEps(o, Nd4j.EPS_THRESHOLD);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
val longHash = Nd4j.exec(new HashCode(this))[0].getLong(0);
|
||||
return Math.abs(longHash) <= Integer.MAX_VALUE ? (int) longHash : (int) (longHash % Integer.MAX_VALUE);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataBuffer shapeInfoDataBuffer() {
|
||||
return shapeInformation;
|
||||
|
@ -6826,4 +6833,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
|
|||
public INDArray ulike() {
|
||||
return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering());
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
package org.nd4j.linalg.api.ops.impl.reduce;
|
||||
|
||||
import lombok.NonNull;
|
||||
import org.nd4j.base.Preconditions;
|
||||
import org.nd4j.linalg.api.buffer.DataType;
|
||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||
import org.nd4j.linalg.api.ops.DynamicCustomOp;
|
||||
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* This is hashCode op wrapper. Basically - simple parallel hash implementation.
|
||||
*
|
||||
* @author raver119@gmail.com
|
||||
*/
|
||||
public class HashCode extends DynamicCustomOp {
|
||||
public HashCode() {
|
||||
//
|
||||
}
|
||||
|
||||
public HashCode(@NonNull INDArray array) {
|
||||
this.inputArguments.add(array);
|
||||
}
|
||||
|
||||
public HashCode(@NonNull INDArray array, @NonNull INDArray result) {
|
||||
this(array);
|
||||
Preconditions.checkArgument(result.dataType() == DataType.LONG && result.isScalar(), "HashCode op expects LONG scalar as output");
|
||||
|
||||
this.outputArguments.add(result);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<LongShapeDescriptor> calculateOutputShape() {
|
||||
return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], DataType.LONG));
|
||||
}
|
||||
|
||||
@Override
|
||||
public String opName() {
|
||||
return "hashcode";
|
||||
}
|
||||
}
|
|
@ -16712,6 +16712,26 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
|
|||
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
|
||||
}
|
||||
// #endif
|
||||
|
||||
/**
|
||||
* This operation calculates hash code, optionally along dimension
|
||||
*/
|
||||
// #if NOT_EXCLUDED(OP_hashcode)
|
||||
@Namespace("nd4j::ops") public static class hashcode extends DeclarableReductionOp {
|
||||
static { Loader.load(); }
|
||||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
|
||||
public hashcode(Pointer p) { super(p); }
|
||||
/** Native array allocator. Access with {@link Pointer#position(long)}. */
|
||||
public hashcode(long size) { super((Pointer)null); allocateArray(size); }
|
||||
private native void allocateArray(long size);
|
||||
@Override public hashcode position(long position) {
|
||||
return (hashcode)super.position(position);
|
||||
}
|
||||
|
||||
public hashcode() { super((Pointer)null); allocate(); }
|
||||
private native void allocate();
|
||||
}
|
||||
// #endif
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -108,7 +108,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
|
|||
|
||||
//2019/06/22 - Known issue: https://github.com/eclipse/deeplearning4j/issues/7935
|
||||
"fake_quant/min_max_vars/.*",
|
||||
"fake_quant/min_max_args/.*"
|
||||
"fake_quant/min_max_args/.*",
|
||||
|
||||
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
|
||||
"multinormal/.*"
|
||||
};
|
||||
|
||||
@BeforeClass
|
||||
|
|
Loading…
Reference in New Issue