[WIP] bits_hamming_distance (#192)

* bits_hamming_distance op

Signed-off-by: raver119 <raver119@gmail.com>

* bits_hamming_distance cuda

Signed-off-by: raver119 <raver119@gmail.com>
master
raver119 2019-08-28 18:20:44 +03:00 committed by GitHub
parent f4860574d7
commit dec296da17
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 283 additions and 0 deletions

View File

@ -0,0 +1,57 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <op_boilerplate.h>
#if NOT_EXCLUDED(OP_bits_hamming_distance)
#include <ops/declarable/CustomOperations.h>
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/hamming.h>
namespace nd4j {
namespace ops {
CUSTOM_OP_IMPL(bits_hamming_distance, 2, 1, true, 0, 0) {
auto x = INPUT_VARIABLE(0);
auto y = INPUT_VARIABLE(1);
auto output = OUTPUT_VARIABLE(0);
REQUIRE_TRUE(x->lengthOf() == y->lengthOf(), 0, "bits_hamming_distance: both arguments must have the same length");
REQUIRE_TRUE(x->dataType() == y->dataType(), 0, "bits_hamming_distance: both arguments must have the same data type");
helpers::hamming(block.launchContext(), *x, *y, *output);
return Status::OK();
}
DECLARE_SHAPE_FN(bits_hamming_distance) {
return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(nd4j::DataType::INT64));
}
DECLARE_TYPES(bits_hamming_distance) {
getOpDescriptor()
->setAllowedInputTypes(0, {ALL_INTS})
->setAllowedInputTypes(1, {ALL_INTS})
->setAllowedOutputTypes(0, {ALL_INDICES})
->setSameMode(true);
}
}
}
#endif

View File

@ -80,6 +80,17 @@ namespace nd4j {
#if NOT_EXCLUDED(OP_cyclic_rshift_bits) #if NOT_EXCLUDED(OP_cyclic_rshift_bits)
DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2); DECLARE_CONFIGURABLE_OP(cyclic_rshift_bits, 1, 1, true, 0, -2);
#endif #endif
/**
* This operation returns hamming distance based on bits
*
* PLEASE NOTE: This operation is applicable only to integer data types
*
* @tparam T
*/
#if NOT_EXCLUDED(OP_bits_hamming_distance)
DECLARE_CUSTOM_OP(bits_hamming_distance, 2, 1, true, 0, 0);
#endif
} }
} }

View File

@ -0,0 +1,72 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/hamming.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename X, typename Z>
static void _hamming(NDArray &x, NDArray &y, NDArray &z) {
auto xEws = x.ews();
auto yEws = y.ews();
auto xBuffer = x.bufferAsT<X>();
auto yBuffer = y.bufferAsT<X>();
Nd4jLong distance = 0;
if (xEws == 1 && yEws == 1 && x.ordering() == y.ordering()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
auto _x = static_cast<unsigned long long>(xBuffer[e]);
auto _y = static_cast<unsigned long long>(yBuffer[e]);
distance += __builtin_popcountll(_x ^ _y);
}
} else if (xEws > 1 && yEws > 1 && x.ordering() == y.ordering()) {
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
auto _x = static_cast<unsigned long long>(xBuffer[e * xEws]);
auto _y = static_cast<unsigned long long>(yBuffer[e * yEws]);
distance += __builtin_popcountll(_x ^ _y);
}
} else {
PRAGMA_OMP_PARALLEL_FOR_SIMD_REDUCTION(+:distance)
for (Nd4jLong e = 0; e < x.lengthOf(); e++) {
auto _x = static_cast<unsigned long long>(x.e<Nd4jLong>(e));
auto _y = static_cast<unsigned long long>(y.e<Nd4jLong>(e));
distance += __builtin_popcountll(_x ^ _y);
}
}
z.p(0, distance);
}
void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) {
BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, (x, y, output), INTEGER_TYPES, INDEXING_TYPES);
}
}
}
}

View File

@ -0,0 +1,95 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#include <ops/declarable/helpers/helpers.h>
#include <ops/declarable/helpers/hamming.h>
namespace nd4j {
namespace ops {
namespace helpers {
template <typename X, typename Z>
static _CUDA_G void _hammingKernel(void *vx, Nd4jLong *xShapeInfo, void *vy, Nd4jLong *yShapeInfo, void *vz, void *reductionBuffer, Nd4jLong length) {
auto x = reinterpret_cast<X*>(vx);
auto y = reinterpret_cast<X*>(vy);
auto z = reinterpret_cast<Z*>(vz);
__shared__ Nd4jLong *shared;
if (threadIdx.x == 0) {
extern __shared__ unsigned char shmem[];
shared = reinterpret_cast<Nd4jLong*>(shmem);
}
__syncthreads();
// we want to nullify temporary memory before accumulating intermediate results
shared[threadIdx.x] = 0;
auto tid = threadIdx.x + blockIdx.x * blockDim.x;
for (Nd4jLong e = tid; e < length; e += blockDim.x * gridDim.x) {
auto _x = static_cast<unsigned long long>(x[shape::getIndexOffset(e, xShapeInfo, length)]);
auto _y = static_cast<unsigned long long>(y[shape::getIndexOffset(e, yShapeInfo, length)]);
// we save intermediate result into shared memory
shared[threadIdx.x] += __popcll(_x ^ _y);
}
__syncthreads();
// now we accumulate values
auto numItems = nd4j::math::nd4j_min<Nd4jLong>(blockDim.x, length);
auto floorPow2 = numItems;
if (floorPow2 & (floorPow2 - 1)) {
while (floorPow2 & (floorPow2 - 1))
floorPow2 &= floorPow2 - 1;
if (threadIdx.x >= floorPow2)
shared[threadIdx.x - floorPow2] = shared[threadIdx.x - floorPow2] + shared[threadIdx.x];
__syncthreads();
}
__syncthreads();
for (Nd4jLong activeThreads = floorPow2 >> 1; activeThreads; activeThreads >>= 1) {
if (threadIdx.x < activeThreads && threadIdx.x + activeThreads < numItems)
shared[threadIdx.x] = shared[threadIdx.x] + shared[threadIdx.x + activeThreads];
__syncthreads();
}
__syncthreads();
// FIXME: do we really want atomicAdd on global memory here
// and store them to output
if (threadIdx.x == 0 && shared[0] > 0)
nd4j::math::atomics::nd4j_atomicAdd<Z>(&z[0], static_cast<Z>(shared[threadIdx.x]));
}
template <typename X, typename Z>
static void _hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &z) {
_hammingKernel<X, Z><<<256, 256, 256 * sizeof(Nd4jLong) + 256, *context->getCudaStream()>>>(x.specialBuffer(), x.specialShapeInfo(), y.specialBuffer(), y.specialShapeInfo(), z.specialBuffer(), nullptr, x.lengthOf());
}
void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output) {
NDArray::prepareSpecialUse({&output}, {&x, &y});
BUILD_DOUBLE_SELECTOR(x.dataType(), output.dataType(), _hamming, (context, x, y, output), INTEGER_TYPES, INDEXING_TYPES);
NDArray::registerSpecialUse({&output}, {&x, &y});
}
}
}
}

View File

@ -0,0 +1,32 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef SD_HAMMING_H
#define SD_HAMMING_H
namespace nd4j {
namespace ops {
namespace helpers {
void hamming(LaunchContext *context, NDArray &x, NDArray &y, NDArray &output);
}
}
}
#endif //DEV_TESTS_HAMMING_H

View File

@ -117,4 +117,20 @@ TEST_F(DeclarableOpsTests16, test_svd_1) {
auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {}); auto status = op.execute({&x}, {&z}, {}, {0, 0, 16}, {});
ASSERT_EQ(Status::OK(), status); ASSERT_EQ(Status::OK(), status);
}
TEST_F(DeclarableOpsTests16, test_hamming_distance_1) {
auto x = NDArrayFactory::create<Nd4jLong>({37, 37, 37});
auto y = NDArrayFactory::create<Nd4jLong>({8723, 8723, 8723});
auto e = NDArrayFactory::create<Nd4jLong>(18);
nd4j::ops::bits_hamming_distance op;
auto result = op.execute({&x, &y}, {}, {});
ASSERT_EQ(Status::OK(), result->status());
auto z = result->at(0);
ASSERT_EQ(e, *z);
delete result;
} }