[WIP] bits_hamming_distance (#192)
* bits_hamming_distance op Signed-off-by: raver119 <raver119@gmail.com> * bits_hamming_distance cuda Signed-off-by: raver119 <raver119@gmail.com>master
parent
f4860574d7
commit
dec296da17
|
@ -0,0 +1,57 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2019 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <op_boilerplate.h>
|
||||
#if NOT_EXCLUDED(OP_bits_hamming_distance)
|
||||
|
||||
#include <ops/declarable/CustomOperations.h>
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <ops/declarable/helpers/hamming.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
/**
 * bits_hamming_distance: consumes two input arrays and produces one output.
 *
 * Both inputs are required to have the same length and the same data type;
 * the bit counting itself is delegated to helpers::hamming, which writes
 * the result into the output array.
 */
CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) {
    auto lhs = INPUT_VARIABLE(0);
    auto rhs = INPUT_VARIABLE(1);
    auto distance = OUTPUT_VARIABLE(0);

    // reject mismatched operands before touching any buffers
    REQUIRE_TRUE(lhs->lengthOf() == rhs->lengthOf(), 0, "bits_hamming_distance: both arguments must have the same length");
    REQUIRE_TRUE(lhs->dataType() == rhs->dataType(), 0, "bits_hamming_distance: both arguments must have the same data type");

    helpers::hamming(block.launchContext(), *lhs, *rhs, *distance);

    return Status::OK();
}
|
||||
|
||||
/**
 * Shape function: the op always yields a single scalar of type INT64,
 * independent of the shapes of its inputs.
 */
DECLARE_SHAPE_FN(bits_hamming_distance) {
    auto scalarShape = ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64);
    return SHAPELIST(scalarShape);
}
|
||||
|
||||
/**
 * Type constraints: both inputs accept any integer type, the output accepts
 * indexing types only.
 *
 * Same-mode is switched off: the shape function of this op always declares an
 * INT64 scalar, so the output type cannot be tied to the input type (inputs
 * may be, for example, 8-bit or 32-bit integers).
 * NOTE(review): assumes same-mode makes the framework expect output dtype ==
 * input dtype - confirm against the op descriptor validation code.
 */
DECLARE_TYPES(bits_hamming_distance) {
    getOpDescriptor()
        ->setAllowedInputTypes(0, {ALL_INTS})
        ->setAllowedInputTypes(1, {ALL_INTS})
        ->setAllowedOutputTypes(0, {ALL_INDICES})
        ->setSameMode(false);
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
|
@ -80,6 +80,17 @@ namespace nd4j {
|
|||
#if NOT_EXCLUDED(OP_cyclic_rshift_bits)
|
||||
DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2);
|
||||
#endif
|
||||
|
||||
/**
|
||||
* This operation returns hamming distance based on bits
|
||||
*
|
||||
* PLEASE NOTE: This operation is applicable only to integer data types
|
||||
*
|
||||
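* Inputs: two integer arrays of equal length and equal data type. Output: a scalar holding the total number of differing bits.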
|
||||
*/
|
||||
#if NOT_EXCLUDED(OP_bits_hamming_distance)
|
||||
DECLARE_CUSTOM_OP(bits_hamming_distance, 2, 1, true, 0, 0);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <ops/declarable/helpers/hamming.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
template <typename X, typename Z>
|
||||
static void _hamming(NDArray &x, NDArray &y, NDArray &z) {
|
||||
auto xEws = x.ews();
|
||||
auto yEws = y.ews();
|
||||
|
||||
auto xBuffer = x.bufferAsT<X>();
|
||||
auto yBuffer = y.bufferAsT<X>();
|
||||
|
||||
Nd4jLong distance = 0;
|
||||
|
||||
if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) {
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
||||
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
||||
auto _x = static_cast<unsigned long long>(xBuffer[e]);
|
||||
auto _y = static_cast<unsigned long long>(yBuffer[e]);
|
||||
|
||||
distance += __builtin_popcountll(_x ^ _y);
|
||||
}
|
||||
|
||||
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
||||
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
||||
auto _x = static_cast<unsigned long long>(xBuffer[e * xEws]);
|
||||
auto _y = static_cast<unsigned long long>(yBuffer[e * yEws]);
|
||||
|
||||
distance += __builtin_popcountll(_x ^ _y);
|
||||
}
|
||||
} else {
|
||||
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
|
||||
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
|
||||
auto _x = static_cast<unsigned long long>(x.e<Nd4jLong>(e));
|
||||
auto _y = static_cast<unsigned long long>(y.e<Nd4jLong>(e));
|
||||
|
||||
distance += __builtin_popcountll(_x ^ _y);
|
||||
}
|
||||
}
|
||||
|
||||
z.p(0, distance);
|
||||
}
|
||||
|
||||
/**
 * Dispatches to the _hamming instantiation that matches the data type of x
 * (integer types) and the data type of output (indexing types).
 *
 * @param context launch context; not referenced by this implementation
 * @param x first operand
 * @param y second operand
 * @param output array that receives the distance
 */
void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) {
    BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(),
                          _hamming, (x, y, output),
                          INTEGER_TYPES, INDEXING_TYPES);
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,95 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#include <ops/declarable/helpers/helpers.h>
|
||||
#include <ops/declarable/helpers/hamming.h>
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
template <typename X, typename Z>
|
||||
static _CUDA_G void _hammingKernel(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, void *reductionBuffer, Nd4jLong length) {
|
||||
auto x = reinterpret_cast<X*>(vx);
|
||||
auto y = reinterpret_cast<X*>(vy);
|
||||
auto z = reinterpret_cast<Z*>(vz);
|
||||
|
||||
__shared__ Nd4jLong *shared;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
extern __shared__ unsigned char shmem[];
|
||||
shared = reinterpret_cast<Nd4jLong*>(shmem);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// we want to nullify temporary memory before accumulating intermediate results
|
||||
shared[threadIdx.x] = 0;
|
||||
|
||||
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
|
||||
for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) {
|
||||
auto _x = static_cast<unsigned long long>(x[shape::getIndexOffset(e, xShapeInfo, length)]);
|
||||
auto _y = static_cast<unsigned long long>(y[shape::getIndexOffset(e, yShapeInfo, length)]);
|
||||
|
||||
// we save intermediate result into shared memory
|
||||
shared[threadIdx.x] += __popcll(_x ^ _y);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// now we accumulate values
|
||||
auto numItems = nd4j::math::nd4j_min<Nd4jLong>(blockDim.x, length);
|
||||
auto floorPow2 = numItems;
|
||||
if (floorPow2 & (floorPow2 - 1)) {
|
||||
|
||||
while (floorPow2 & (floorPow2 - 1))
|
||||
floorPow2 &= floorPow2 - 1;
|
||||
|
||||
if (threadIdx.x >= floorPow2)
|
||||
shared[threadIdx.x - floorPow2] = shared[threadIdx.x - floorPow2] + shared[threadIdx.x];
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) {
|
||||
if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < numItems)
|
||||
shared[threadIdx.x] = shared[threadIdx.x] + shared[threadIdx.x + activeThreads];
|
||||
|
||||
__syncthreads();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// FIXME: do we really want atomicAdd on global memory here
|
||||
// and store them to output
|
||||
if (threadIdx.x == 0 && shared[0] > 0)
|
||||
nd4j::math::atomics::nd4j_atomicAdd<Z>(&z[0], static_cast<Z>(shared[threadIdx.x]));
|
||||
}
|
||||
|
||||
template <typename X, typename Z>
|
||||
static void _hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &z) {
|
||||
_hammingKernel<X, Z><<<256, 256, 256 * sizeof(Nd4jLong) + 256, *context->getCudaStream()>>>(x.specialBuffer(), x.specialShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.specialBuffer(), nullptr, x.lengthOf());
|
||||
}
|
||||
|
||||
/**
 * Dispatches to the _hamming instantiation that matches the data type of x
 * (integer types) and the data type of output (indexing types).
 *
 * The dispatch is bracketed by NDArray::prepareSpecialUse and
 * NDArray::registerSpecialUse, with output listed first and x, y second.
 *
 * @param context launch context forwarded to the launcher
 * @param x first operand
 * @param y second operand
 * @param output array that receives the distance
 */
void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) {
    NDArray::prepareSpecialUse({&output}, {&x, &y});

    BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(),
                          _hamming, (context, x, y, output),
                          INTEGER_TYPES, INDEXING_TYPES);

    NDArray::registerSpecialUse({&output}, {&x, &y});
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
/*******************************************************************************
|
||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
||||
*
|
||||
* This program and the accompanying materials are made available under the
|
||||
* terms of the Apache License, Version 2.0 which is available at
|
||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* License for the specific language governing permissions and limitations
|
||||
* under the License.
|
||||
*
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
******************************************************************************/
|
||||
|
||||
//
|
||||
// @author raver119@gmail.com
|
||||
//
|
||||
|
||||
#ifndef SD_HAMMING_H
|
||||
#define SD_HAMMING_H
|
||||
|
||||
namespace nd4j {
|
||||
namespace ops {
|
||||
namespace helpers {
|
||||
void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#endif //SD_HAMMING_H
|
|
@ -117,4 +117,20 @@ TEST_F(DeclarableOpsTests16, test_svd_1) {
|
|||
auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {});
|
||||
|
||||
ASSERT_EQ(Status::OK(), status);
|
||||
}
|
||||
|
||||
// 37 ^ 8723 has 6 bits set; three identical pairs give 3 * 6 = 18.
TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
    auto lhs = NDArrayFactory::create<Nd4jLong>({37, 37, 37});
    auto rhs = NDArrayFactory::create<Nd4jLong>({8723, 8723, 8723});
    auto expected = NDArrayFactory::create<Nd4jLong>(18);

    nd4j::ops::bits_hamming_distance op;
    auto results = op.execute({&lhs, &rhs}, {}, {});
    ASSERT_EQ(Status::OK(), results->status());

    auto actual = results->at(0);
    ASSERT_EQ(expected, *actual);

    delete results;
}
|
Loading…
Reference in New Issue