[WIP] INDArray hashCode() impl (#50)

* initial commit

Signed-off-by: raver119@gmail.com <raver119@gmail.com>

* one more initial commit

Signed-off-by: raver119 <raver119@gmail.com>

* parallel hashCode prototype

Signed-off-by: raver119 <raver119@gmail.com>

* longBytes for hashCode

Signed-off-by: raver119 <raver119@gmail.com>

* INDArray hashCode java side

Signed-off-by: raver119 <raver119@gmail.com>

* few tests fixed for MSVC

Signed-off-by: raver119 <raver119@gmail.com>

* Small gradcheck validation util fix - hash names not SDVariables

Signed-off-by: Alex Black <blacka101@gmail.com>

* Small fix + ignore for logged issue

Signed-off-by: Alex Black <blacka101@gmail.com>

* - scrollable iterator fix
- sptree hashset replaced with collection

Signed-off-by: raver119 <raver119@gmail.com>

* hashcode exception removed

* int hashCode for java
master
raver119 2019-07-10 14:32:12 +03:00 committed by AlexDBlack
parent cc65c01118
commit 5708fc087a
18 changed files with 445 additions and 19 deletions

View File

@ -114,6 +114,11 @@ public class ScrollableMultiDataSetIterator implements MultiDataSetIterator {
return p; return p;
} }
@Override
public void remove() {
//
}
@Override @Override
public MultiDataSet next(int i) { public MultiDataSet next(int i) {
throw new UnsupportedOperationException(); throw new UnsupportedOperationException();

View File

@ -29,7 +29,8 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import java.io.Serializable; import java.io.Serializable;
import java.util.HashSet; import java.util.ArrayList;
import java.util.Collection;
import java.util.Set; import java.util.Set;
@ -55,20 +56,20 @@ public class SpTree implements Serializable {
private int nodeCapacity; private int nodeCapacity;
private int numChildren = 2; private int numChildren = 2;
private boolean isLeaf = true; private boolean isLeaf = true;
private Set<INDArray> indices; private Collection<INDArray> indices;
private SpTree[] children; private SpTree[] children;
private static Logger log = LoggerFactory.getLogger(SpTree.class); private static Logger log = LoggerFactory.getLogger(SpTree.class);
private String similarityFunction = Distance.EUCLIDEAN.toString(); private String similarityFunction = Distance.EUCLIDEAN.toString();
public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Set<INDArray> indices, public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
String similarityFunction) { String similarityFunction) {
init(parent, data, corner, width, indices, similarityFunction); init(parent, data, corner, width, indices, similarityFunction);
} }
public SpTree(INDArray data, Set<INDArray> indices, String similarityFunction) { public SpTree(INDArray data, Collection<INDArray> indices, String similarityFunction) {
this.indices = indices; this.indices = indices;
this.N = data.rows(); this.N = data.rows();
this.D = data.columns(); this.D = data.columns();
@ -90,26 +91,26 @@ public class SpTree implements Serializable {
} }
public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Set<INDArray> indices) { public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices) {
this(parent, data, corner, width, indices, "euclidean"); this(parent, data, corner, width, indices, "euclidean");
} }
public SpTree(INDArray data, Set<INDArray> indices) { public SpTree(INDArray data, Collection<INDArray> indices) {
this(data, indices, "euclidean"); this(data, indices, "euclidean");
} }
public SpTree(INDArray data) { public SpTree(INDArray data) {
this(data, new HashSet<INDArray>()); this(data, new ArrayList<INDArray>());
} }
public MemoryWorkspace workspace() { public MemoryWorkspace workspace() {
return null; return null;
} }
private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Set<INDArray> indices, private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
String similarityFunction) { String similarityFunction) {
this.parent = parent; this.parent = parent;

View File

@ -0,0 +1,54 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_hashcode)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/transforms.h>
#include <ops/declarable/helpers/hashcode.h>
namespace nd4j {
namespace ops {
REDUCTION_OP_IMPL(hashcode, 1, 1, false, 0, 0) {
REQUIRE_TRUE(block.width() == 1, 0, "hashcode: this op can't be applied along dimension");
auto input = INPUT_VARIABLE(0);
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(output->isScalar(), 0, "hashcode: this op requires scalar output");
helpers::hashCode(block.launchContext(), *input, *output);
return Status::OK();
};
DECLARE_TYPES(hashcode) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS})
->setAllowedInputTypes(1, {ALL_INTS})
->setAllowedOutputTypes({nd4j::DataType::INT64});
};
}
}
#endif

View File

@ -199,6 +199,13 @@ namespace nd4j {
DECLARE_CONFIGURABLE_OP(standardize, 1, 1, true, 0, -2); DECLARE_CONFIGURABLE_OP(standardize, 1, 1, true, 0, -2);
DECLARE_CUSTOM_OP(standardize_bp, 2, 1, false, 0, -2); DECLARE_CUSTOM_OP(standardize_bp, 2, 1, false, 0, -2);
#endif #endif
/**
* This operation calculates hash code, optionally along dimension
*/
#if NOT_EXCLUDED(OP_hashcode)
DECLARE_REDUCTION_OP(hashcode, 1, 1, false, 0, 0);
#endif
} }
} }

View File

@ -0,0 +1,101 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/hashcode.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
static void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) {
auto blockSize = 32;
auto length = array.lengthOf();
int numBlocks = length / blockSize + ((length % blockSize == 0) ? 0 : 1);
auto tempA = NDArrayFactory::create<Nd4jLong>('c', {numBlocks}, context);
auto tempB = NDArrayFactory::create<Nd4jLong>('c', { numBlocks / blockSize + 1}, context);
auto buffer = array.bufferAsT<T>();
auto tempBufferA = tempA.bufferAsT<Nd4jLong>();
auto tempBufferB = tempB.bufferAsT<Nd4jLong>();
// default buffer is the first one, because it might be the last one in case of small arrays (< blockSize)
auto tempBuffer = tempBufferA;
auto tempResult = tempBufferB;
// we divide array into 32 element chunks, and store intermediate results once
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int b = 0; b < numBlocks; b++) {
auto blockBuffer = buffer + b * numBlocks;
Nd4jLong r = 1;
for (int e = 0; e < blockSize && e + (b * numBlocks) < length; e++) {
auto v = longBytes<T>(blockBuffer[e]);
r = 31 * r + v;
}
tempBuffer[b] = r;
}
// we replace pointer with intermediate one, and repeat only one chunk left
int iterationCount = 0;
while (numBlocks > 1) {
int lastLength = numBlocks;
numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1);
PRAGMA_OMP_PARALLEL_FOR_SIMD
for (int b = 0; b < numBlocks; b++) {
auto blockBuffer = tempBuffer + b * numBlocks;
Nd4jLong r = 1;
for (int e = 0; e < blockSize && e + (b * numBlocks) < lastLength; e++) {
auto v = longBytes<T>(blockBuffer[e]);
r = 31 * r + v;
}
tempResult[b] = r;
}
iterationCount++;
// swapping buffers
if (iterationCount % 2 == 0) {
tempBuffer = tempBufferA;
tempResult = tempBufferB;
} else {
tempBuffer = tempBufferB;
tempResult = tempBufferA;
}
}
if (length <= blockSize)
result.p(0, tempBufferA[0]);
else
result.p(0, tempResult[0]);
}
void hashCode(LaunchContext *context, NDArray &array, NDArray &result) {
BUILD_SINGLE_SELECTOR(array.dataType(), hashCode_, (context, array, result), LIBND4J_TYPES);
}
}
}
}

View File

@ -0,0 +1,32 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/hashcode.h>
namespace nd4j {
namespace ops {
namespace helpers {
void hashCode(LaunchContext *context, NDArray &array, NDArray &result) {
}
}
}
}

View File

@ -0,0 +1,70 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef DEV_TESTS_HASHCODE_H
#define DEV_TESTS_HASHCODE_H
#include "helpers.h"
namespace nd4j {
namespace ops {
namespace helpers {
template <typename T>
FORCEINLINE Nd4jLong longBytes(T value);
template <>
FORCEINLINE Nd4jLong longBytes(float value) {
int intie = *(int *)&value;
return static_cast<Nd4jLong>(intie);
}
template <>
FORCEINLINE Nd4jLong longBytes(double value) {
Nd4jLong longie = *(Nd4jLong *)&value;
return longie;
}
template <>
FORCEINLINE Nd4jLong longBytes(float16 value) {
return longBytes<float>((float) value);
}
template <>
FORCEINLINE Nd4jLong longBytes(Nd4jLong value) {
return value;
}
template <>
FORCEINLINE Nd4jLong longBytes(bfloat16 value) {
return longBytes<float>((float) value);
}
template <typename T>
FORCEINLINE Nd4jLong longBytes(T value) {
return longBytes<Nd4jLong>((Nd4jLong) value);
}
void hashCode(LaunchContext *context, NDArray &array, NDArray &result);
}
}
}
#endif //DEV_TESTS_HASHCODE_H

View File

@ -1925,10 +1925,10 @@ TEST_F(DeclarableOpsTests10, Image_NonMaxSuppressing_2) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) { TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
int axis = 0;
NDArray images = NDArrayFactory::create<double>('c', {1,2,2,1}, {1,2,3,4}); NDArray images = NDArrayFactory::create<double>('c', {1,2,2,1}, {1,2,3,4});
NDArray boxes = NDArrayFactory::create<float>('c', {1,4}, {0,0,1,1}); NDArray boxes = NDArrayFactory::create<float>('c', {1,4}, {0,0,1,1});
NDArray boxI = NDArrayFactory::create<int>('c', {1}, {(int)0}); NDArray boxI = NDArrayFactory::create<int>('c', {1}, {axis});
NDArray cropSize = NDArrayFactory::create<int>({1, 1}); NDArray cropSize = NDArrayFactory::create<int>({1, 1});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});
@ -1949,10 +1949,10 @@ TEST_F(DeclarableOpsTests10, Image_CropAndResize_1) {
//////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////
TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) { TEST_F(DeclarableOpsTests10, Image_CropAndResize_2) {
int axis = 0;
NDArray images = NDArrayFactory::create<float>('c', {1,2,2,1}, {1,2,3,4}); NDArray images = NDArrayFactory::create<float>('c', {1,2,2,1}, {1,2,3,4});
NDArray boxes = NDArrayFactory::create<float>('c', {1,4}, {0,0,1,1}); NDArray boxes = NDArrayFactory::create<float>('c', {1,4}, {0,0,1,1});
NDArray boxI = NDArrayFactory::create<int>('c', {1}, {(int)0}); NDArray boxI = NDArrayFactory::create<int>('c', {1}, {axis});
NDArray cropSize = NDArrayFactory::create<int>({1, 1}); NDArray cropSize = NDArrayFactory::create<int>({1, 1});
//NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f}); //NDArray<float> ('c', {6}, {0.9f, .75f, .6f, .95f, .5f, .3f});

View File

@ -250,6 +250,46 @@ TEST_F(DeclarableOpsTests15, Test_layer_norm_bp_1) {
delete result; delete result;
} }
TEST_F(DeclarableOpsTests15, test_hashCode_1) {
auto x = NDArrayFactory::create<int>('c', {10});
auto y = NDArrayFactory::create<int>('c', {10});
x.linspace(1.);
y.linspace(2.);
nd4j::ops::hashcode op;
auto resultA0 = op.execute({&x}, {}, {});
auto resultA1 = op.execute({&x}, {}, {});
auto resultB0 = op.execute({&y}, {}, {});
ASSERT_EQ(*resultA0->at(0), *resultA1->at(0));
ASSERT_NE(*resultA0->at(0), *resultB0->at(0));
delete resultA0;
delete resultA1;
delete resultB0;
}
TEST_F(DeclarableOpsTests15, test_hashCode_2) {
auto x = NDArrayFactory::create<int>('c', {1027});
auto y = NDArrayFactory::create<int>('c', {1027});
x.linspace(1.);
y.linspace(2.);
nd4j::ops::hashcode op;
auto resultA0 = op.execute({&x}, {}, {});
auto resultA1 = op.execute({&x}, {}, {});
auto resultB0 = op.execute({&y}, {}, {});
ASSERT_EQ(*resultA0->at(0), *resultA1->at(0));
ASSERT_NE(*resultA0->at(0), *resultB0->at(0));
delete resultA0;
delete resultA1;
delete resultB0;
}
TEST_F(DeclarableOpsTests15, test_lstmBlock_1) { TEST_F(DeclarableOpsTests15, test_lstmBlock_1) {
auto x0 = NDArrayFactory::create<Nd4jLong>(5); auto x0 = NDArrayFactory::create<Nd4jLong>(5);
auto x1 = NDArrayFactory::create<float>('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f}); auto x1 = NDArrayFactory::create<float>('c', {5, 1, 4}, {0.7787856f, 0.80119777f, 0.72437465f, 0.23089433f, 0.72714126f, 0.18039072f, 0.50563407f, 0.89252293f, 0.5461209f, 0.92336726f, 0.085571885f, 0.7937801f, 0.65908563f, 0.55552566f, 0.15962744f, 0.30874777f, 0.15476847f, 0.46954823f, 0.9938899f, 0.6112741f});

View File

@ -842,9 +842,10 @@ TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_2) {
} }
TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) { TEST_F(DeclarableOpsTests4, Test_StridedSlice_Alex_3) {
int axis = 0;
auto x = NDArrayFactory::create<double>('c', {1}, {10}); auto x = NDArrayFactory::create<double>('c', {1}, {10});
auto begin = NDArrayFactory::create<int>('c', {1}, {(int)0}); auto begin = NDArrayFactory::create<int>('c', {1}, {axis});
auto end = NDArrayFactory::create<int>('c', {1}, {(int)0}); auto end = NDArrayFactory::create<int>('c', {1}, {axis});
auto stride = NDArrayFactory::create<int>('c', {1}, {1}); auto stride = NDArrayFactory::create<int>('c', {1}, {1});
//x.linspace(1); //x.linspace(1);
//auto exp = NDArrayFactory::create<double>('c', {1,3,4,5}); //auto exp = NDArrayFactory::create<double>('c', {1,3,4,5});

View File

@ -643,8 +643,8 @@ TEST_F(ParityOpsTests, Test_Select_2) {
} }
TEST_F(ParityOpsTests, Test_Select_3) { TEST_F(ParityOpsTests, Test_Select_3) {
bool value = false;
auto mask = NDArrayFactory::create<bool>('c', {1, 1}, {false}); auto mask = NDArrayFactory::create<bool>('c', {1, 1}, {value});
auto x = NDArrayFactory::create<float>('c', {1, 1}, {1}); auto x = NDArrayFactory::create<float>('c', {1, 1}, {1});
auto y = NDArrayFactory::create<float>('c', {1, 1}, {2}); auto y = NDArrayFactory::create<float>('c', {1, 1}, {2});
auto exp = NDArrayFactory::create<float>('c', {1, 1}, {2}); auto exp = NDArrayFactory::create<float>('c', {1, 1}, {2});

View File

@ -1601,6 +1601,30 @@ TEST_F(PlaygroundTests, test_assign_float) {
} }
TEST_F(PlaygroundTests, test_hash_1) {
std::vector<int> vec;
for (int e = 1; e < 100000; e++)
vec.emplace_back(e);
int size = vec.size();
int r = 0;
PRAGMA_OMP_PARALLEL_FOR_REDUCTION(+:r)
for (int e = 0; e < size; e++) {
r += 31 * vec[e];
}
nd4j_printf("Result: %i\n", r);
}
TEST_F(PlaygroundTests, test_hash_2) {
auto x = NDArrayFactory::create<int>('c', {5, 10000});
x.linspace(1.f);
//auto h = x.reduceNumber(reduce::LongOps::HashCode);
//h.printIndexedBuffer("hash");
}
/* /*
TEST_F(PlaygroundTests, test_manual_loop) { TEST_F(PlaygroundTests, test_manual_loop) {
const unsigned int len = 32 * 128 * 256 * 256; const unsigned int len = 32 * 128 * 256 * 256;

View File

@ -596,8 +596,6 @@ public class GradCheckUtil {
DifferentialFunction[] dfs = sd.functions(); DifferentialFunction[] dfs = sd.functions();
List<SDVariable> vars = sd.variables(); List<SDVariable> vars = sd.variables();
Set<SDVariable> varsSet = new HashSet<>(vars);
Preconditions.checkState(vars.size() == varsSet.size(), "Duplicate variables in variables() list");
Set<String> varSetStr = new HashSet<>(); Set<String> varSetStr = new HashSet<>();
for(SDVariable v : vars){ for(SDVariable v : vars){
if(varSetStr.contains(v.getVarName())){ if(varSetStr.contains(v.getVarName())){
@ -605,6 +603,7 @@ public class GradCheckUtil {
} }
varSetStr.add(v.getVarName()); varSetStr.add(v.getVarName());
} }
Preconditions.checkState(vars.size() == varSetStr.size(), "Duplicate variables in variables() list");
//1. Check incomingArgsReverse and outgoingArgsReverse //1. Check incomingArgsReverse and outgoingArgsReverse
Map<String,SameDiffOp> ops = sd.getOps(); Map<String,SameDiffOp> ops = sd.getOps();

View File

@ -173,6 +173,7 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp.class, org.nd4j.linalg.api.ops.impl.meta.ReduceMetaOp.class,
org.nd4j.linalg.api.ops.impl.nlp.CbowRound.class, org.nd4j.linalg.api.ops.impl.nlp.CbowRound.class,
org.nd4j.linalg.api.ops.impl.nlp.SkipGramRound.class, org.nd4j.linalg.api.ops.impl.nlp.SkipGramRound.class,
org.nd4j.linalg.api.ops.impl.reduce.HashCode.class,
org.nd4j.linalg.api.ops.impl.reduce.Mmul.class, org.nd4j.linalg.api.ops.impl.reduce.Mmul.class,
org.nd4j.linalg.api.ops.impl.reduce.MmulBp.class, org.nd4j.linalg.api.ops.impl.reduce.MmulBp.class,
org.nd4j.linalg.api.ops.impl.reduce.Moments.class, org.nd4j.linalg.api.ops.impl.reduce.Moments.class,

View File

@ -38,6 +38,7 @@ import org.nd4j.linalg.api.iter.NdIndexIterator;
import org.nd4j.linalg.api.memory.MemoryWorkspace; import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.reduce.HashCode;
import org.nd4j.linalg.api.ops.impl.reduce.bool.All; import org.nd4j.linalg.api.ops.impl.reduce.bool.All;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce.floating.*; import org.nd4j.linalg.api.ops.impl.reduce.floating.*;
@ -5444,6 +5445,12 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return equalsWithEps(o, Nd4j.EPS_THRESHOLD); return equalsWithEps(o, Nd4j.EPS_THRESHOLD);
} }
@Override
public int hashCode() {
val longHash = Nd4j.exec(new HashCode(this))[0].getLong(0);
return Math.abs(longHash) <= Integer.MAX_VALUE ? (int) longHash : (int) (longHash % Integer.MAX_VALUE);
}
@Override @Override
public DataBuffer shapeInfoDataBuffer() { public DataBuffer shapeInfoDataBuffer() {
return shapeInformation; return shapeInformation;
@ -6826,4 +6833,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
public INDArray ulike() { public INDArray ulike() {
return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering()); return Nd4j.createUninitialized(this.dataType(), this.shape(), this.ordering());
} }
} }

View File

@ -0,0 +1,59 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.reduce;
import lombok.NonNull;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import java.util.Collections;
import java.util.List;
/**
* This is hashCode op wrapper. Basically - simple parallel hash implementation.
*
* @author raver119@gmail.com
*/
public class HashCode extends DynamicCustomOp {
public HashCode() {
//
}
public HashCode(@NonNull INDArray array) {
this.inputArguments.add(array);
}
public HashCode(@NonNull INDArray array, @NonNull INDArray result) {
this(array);
Preconditions.checkArgument(result.dataType() == DataType.LONG && result.isScalar(), "HashCode op expects LONG scalar as output");
this.outputArguments.add(result);
}
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
return Collections.singletonList(LongShapeDescriptor.fromShape(new long[0], DataType.LONG));
}
@Override
public String opName() {
return "hashcode";
}
}

View File

@ -16712,6 +16712,26 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block);
} }
// #endif // #endif
/**
* This operation calculates hash code, optionally along dimension
*/
// #if NOT_EXCLUDED(OP_hashcode)
@Namespace("nd4j::ops") public static class hashcode extends DeclarableReductionOp {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public hashcode(Pointer p) { super(p); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public hashcode(long size) { super((Pointer)null); allocateArray(size); }
private native void allocateArray(long size);
@Override public hashcode position(long position) {
return (hashcode)super.position(position);
}
public hashcode() { super((Pointer)null); allocate(); }
private native void allocate();
}
// #endif

View File

@ -108,7 +108,10 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
//2019/06/22 - Known issue: https://github.com/eclipse/deeplearning4j/issues/7935 //2019/06/22 - Known issue: https://github.com/eclipse/deeplearning4j/issues/7935
"fake_quant/min_max_vars/.*", "fake_quant/min_max_vars/.*",
"fake_quant/min_max_args/.*" "fake_quant/min_max_args/.*",
//2019/07/09 - Need "Multinomial" op - https://github.com/eclipse/deeplearning4j/issues/7913
"multinormal/.*"
}; };
@BeforeClass @BeforeClass