[WIP] bits_hamming_distance (#192)
* bits_hamming_distance op Signed-off-by: raver119 <raver119@gmail.com> * bits_hamming_distance cuda Signed-off-by: raver119 <raver119@gmail.com>master
parent
f4860574d7
commit
dec296da17
|
@ -0,0 +1,57 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <op_boilerplate.h>
|
||||||
|
#if NOT_EXCLUDED(OP_bits_hamming_distance)
|
||||||
|
|
||||||
|
#include <ops/declarable/CustomOperations.h>
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
#include <ops/declarable/helpers/hamming.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) {
|
||||||
|
auto x = INPUT_VARIABLE(0);
|
||||||
|
auto y = INPUT_VARIABLE(1);
|
||||||
|
auto output = OUTPUT_VARIABLE(0);
|
||||||
|
|
||||||
|
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "bits_hamming_distance: both arguments must have the same length");
|
||||||
|
REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "bits_hamming_distance: both arguments must have the same data type");
|
||||||
|
|
||||||
|
helpers::hamming(block.launchContext(), *x, *y, *output);
|
||||||
|
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_SHAPE_FN(bits_hamming_distance) {
|
||||||
|
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64));
|
||||||
|
}
|
||||||
|
|
||||||
|
DECLARE_TYPES(bits_hamming_distance) {
|
||||||
|
getOpDescriptor()
|
||||||
|
->setAllowedInputTypes(0, {ALL_INTS})
|
||||||
|
->setAllowedInputTypes(1, {ALL_INTS})
|
||||||
|
->setAllowedOutputTypes(0, {ALL_INDICES})
|
||||||
|
->setSameMode(true);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif
|
|
@ -80,6 +80,17 @@ namespace nd4j {
|
||||||
#if NOT_EXCLUDED(OP_cyclic_rshift_bits)
|
#if NOT_EXCLUDED(OP_cyclic_rshift_bits)
|
||||||
DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2);
|
DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
/**
|
||||||
|
* This operation returns hamming distance based on bits
|
||||||
|
*
|
||||||
|
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||||
|
*
|
||||||
|
* @tparam T
|
||||||
|
*/
|
||||||
|
#if NOT_EXCLUDED(OP_bits_hamming_distance)
|
||||||
|
DECLARE_CUSTOM_OP(bits_hamming_distance, 2, 1, true, 0, 0);
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
#include <ops/declarable/helpers/hamming.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
template <typename X, typename Z>
|
||||||
|
static void _hamming(NDArray &x, NDArray &y, NDArray &z) {
|
||||||
|
auto xEws = x.ews();
|
||||||
|
auto yEws = y.ews();
|
||||||
|
|
||||||
|
auto xBuffer = x.bufferAsT<X>();
|
||||||
|
auto yBuffer = y.bufferAsT<X>();
|
||||||
|
|
||||||
|
Nd4jLong distance = 0;
|
||||||
|
|
||||||
|
if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
||||||
|
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
||||||
|
auto _x = static_cast<unsigned long long>(xBuffer[e]);
|
||||||
|
auto _y = static_cast<unsigned long long>(yBuffer[e]);
|
||||||
|
|
||||||
|
distance += __builtin_popcountll(_x ^ _y);
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
||||||
|
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
||||||
|
auto _x = static_cast<unsigned long long>(xBuffer[e * xEws]);
|
||||||
|
auto _y = static_cast<unsigned long long>(yBuffer[e * yEws]);
|
||||||
|
|
||||||
|
distance += __builtin_popcountll(_x ^ _y);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
||||||
|
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
||||||
|
auto _x = static_cast<unsigned long long>(x.e<Nd4jLong>(e));
|
||||||
|
auto _y = static_cast<unsigned long long>(y.e<Nd4jLong>(e));
|
||||||
|
|
||||||
|
distance += __builtin_popcountll(_x ^ _y);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
z.p(0, distance);
|
||||||
|
}
|
||||||
|
|
||||||
|
void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) {
|
||||||
|
BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, (x, y, output), INTEGER_TYPES, INDEXING_TYPES);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,95 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <ops/declarable/helpers/helpers.h>
|
||||||
|
#include <ops/declarable/helpers/hamming.h>
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
template <typename X, typename Z>
|
||||||
|
static _CUDA_G void _hammingKernel(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, void *reductionBuffer, Nd4jLong length) {
|
||||||
|
auto x = reinterpret_cast<X*>(vx);
|
||||||
|
auto y = reinterpret_cast<X*>(vy);
|
||||||
|
auto z = reinterpret_cast<Z*>(vz);
|
||||||
|
|
||||||
|
__shared__ Nd4jLong *shared;
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
extern __shared__ unsigned char shmem[];
|
||||||
|
shared = reinterpret_cast<Nd4jLong*>(shmem);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// we want to nullify temporary memory before accumulating intermediate results
|
||||||
|
shared[threadIdx.x] = 0;
|
||||||
|
|
||||||
|
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||||
|
for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) {
|
||||||
|
auto _x = static_cast<unsigned long long>(x[shape::getIndexOffset(e, xShapeInfo, length)]);
|
||||||
|
auto _y = static_cast<unsigned long long>(y[shape::getIndexOffset(e, yShapeInfo, length)]);
|
||||||
|
|
||||||
|
// we save intermediate result into shared memory
|
||||||
|
shared[threadIdx.x] += __popcll(_x ^ _y);
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// now we accumulate values
|
||||||
|
auto numItems = nd4j::math::nd4j_min<Nd4jLong>(blockDim.x, length);
|
||||||
|
auto floorPow2 = numItems;
|
||||||
|
if (floorPow2 & (floorPow2 - 1)) {
|
||||||
|
|
||||||
|
while (floorPow2 & (floorPow2 - 1))
|
||||||
|
floorPow2 &= floorPow2 - 1;
|
||||||
|
|
||||||
|
if (threadIdx.x >= floorPow2)
|
||||||
|
shared[threadIdx.x - floorPow2] = shared[threadIdx.x - floorPow2] + shared[threadIdx.x];
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) {
|
||||||
|
if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < numItems)
|
||||||
|
shared[threadIdx.x] = shared[threadIdx.x] + shared[threadIdx.x + activeThreads];
|
||||||
|
|
||||||
|
__syncthreads();
|
||||||
|
}
|
||||||
|
__syncthreads();
|
||||||
|
|
||||||
|
// FIXME: do we really want atomicAdd on global memory here
|
||||||
|
// and store them to output
|
||||||
|
if (threadIdx.x == 0 && shared[0] > 0)
|
||||||
|
nd4j::math::atomics::nd4j_atomicAdd<Z>(&z[0], static_cast<Z>(shared[threadIdx.x]));
|
||||||
|
}
|
||||||
|
|
||||||
|
template <typename X, typename Z>
|
||||||
|
static void _hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &z) {
|
||||||
|
_hammingKernel<X, Z><<<256, 256, 256 * sizeof(Nd4jLong) + 256, *context->getCudaStream()>>>(x.specialBuffer(), x.specialShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.specialBuffer(), nullptr, x.lengthOf());
|
||||||
|
}
|
||||||
|
|
||||||
|
void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) {
|
||||||
|
NDArray::prepareSpecialUse({&output}, {&x, &y});
|
||||||
|
BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, (context, x, y, output), INTEGER_TYPES, INDEXING_TYPES);
|
||||||
|
NDArray::registerSpecialUse({&output}, {&x, &y});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,32 @@
|
||||||
|
/*******************************************************************************
|
||||||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||||
|
*
|
||||||
|
* This program and the accompanying materials are made available under the
|
||||||
|
* terms of the Apache License, Version 2.0 which is available at
|
||||||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
* License for the specific language governing permissions and limitations
|
||||||
|
* under the License.
|
||||||
|
*
|
||||||
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
|
******************************************************************************/
|
||||||
|
|
||||||
|
//
|
||||||
|
// @author raver119@gmail.com
|
||||||
|
//
|
||||||
|
|
||||||
|
#ifndef SD_HAMMING_H
|
||||||
|
#define SD_HAMMING_H
|
||||||
|
|
||||||
|
namespace nd4j {
|
||||||
|
namespace ops {
|
||||||
|
namespace helpers {
|
||||||
|
void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#endif //DEV_TESTS_HAMMING_H
|
|
@ -117,4 +117,20 @@ TEST_F(DeclarableOpsTests16, test_svd_1) {
|
||||||
auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {});
|
auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {});
|
||||||
|
|
||||||
ASSERT_EQ(Status::OK(), status);
|
ASSERT_EQ(Status::OK(), status);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
|
||||||
|
auto x = NDArrayFactory::create<Nd4jLong>({37, 37, 37});
|
||||||
|
auto y = NDArrayFactory::create<Nd4jLong>({8723, 8723, 8723});
|
||||||
|
auto e = NDArrayFactory::create<Nd4jLong>(18);
|
||||||
|
|
||||||
|
nd4j::ops::bits_hamming_distance op;
|
||||||
|
auto result = op.execute({&x, &y}, {}, {});
|
||||||
|
ASSERT_EQ(Status::OK(), result->status());
|
||||||
|
|
||||||
|
auto z = result->at(0);
|
||||||
|
|
||||||
|
ASSERT_EQ(e, *z);
|
||||||
|
|
||||||
|
delete result;
|
||||||
}
|
}
|
Loading…
Reference in New Issue