2019-07-10 13:32:12 +02:00
|
|
|
/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/
|
|
|
|
|
|
|
|
//
// @author raver119@gmail.com
//
|
|
|
|
|
|
|
|
#include <ops/declarable/helpers/hashcode.h>
|
|
|
|
|
2019-07-20 07:58:44 +02:00
|
|
|
|
2019-07-10 13:32:12 +02:00
|
|
|
namespace nd4j {
|
|
|
|
namespace ops {
|
|
|
|
namespace helpers {
|
2019-07-20 07:58:44 +02:00
|
|
|
/**
 * First hashing pass: folds every fixed-size chunk of the input into one partial hash.
 *
 * Chunk `b` covers the elements [b * blockSize, min((b + 1) * blockSize, length)) of
 * `buffer`. Its hash starts at 1 and is updated as r = 31 * r + longBytes<T>(element)
 * for each element of the chunk, in order; the result is stored in tempBuffer[b].
 *
 * The chunk loop advances by gridDim.x * blockDim.x, so any 1D launch configuration
 * visits all `numBlocks` chunks.
 *
 * NOTE(review): the previous version used `b * numBlocks` as the chunk offset, which made
 * chunks overlap or skip elements whenever numBlocks != blockSize. If another backend
 * mirrors the old stride, update it together with this kernel so hash values agree.
 *
 * @tparam T         element type of the input buffer
 * @param buffer     input elements (only read by this kernel)
 * @param tempBuffer output, one partial hash per chunk; needs at least numBlocks entries
 * @param numBlocks  number of chunks to hash
 * @param blockSize  number of elements per chunk
 * @param length     total number of elements in buffer
 */
template <typename T>
static __global__ void splitBufferToChuncks(T* buffer, Nd4jLong* tempBuffer, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong length) {
    // 64-bit indices: the products below must not wrap for large inputs
    const Nd4jLong step = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;

    for (Nd4jLong b = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x; b < numBlocks; b += step) {
        // consecutive chunks are blockSize elements apart
        const Nd4jLong offset = b * blockSize;
        auto blockBuffer = buffer + offset;

        Nd4jLong r = 1LL;
        // the second condition trims the last, possibly partial, chunk
        for (Nd4jLong e = 0; e < blockSize && offset + e < length; e++) {
            auto v = longBytes<T>(blockBuffer[e]);
            r = 31LL * r + v;
        }

        tempBuffer[b] = r;
    }
}
|
|
|
|
|
|
|
|
/**
 * Reduction pass: folds groups of partial hashes into a shorter list of partial hashes.
 *
 * Group `b` covers the entries [b * blockSize, min((b + 1) * blockSize, lastLength)) of
 * `tempBuffer`. Its hash starts at 1 and is updated as r = 31 * r + longBytes<T>(entry)
 * for each entry of the group, in order; the result is stored in tempResult[b].
 *
 * The group loop advances by gridDim.x * blockDim.x, so any 1D launch configuration
 * visits all `numBlocks` groups.
 *
 * NOTE(review): the previous version used `b * numBlocks` as the group offset, which made
 * groups overlap or skip entries whenever numBlocks != blockSize. If another backend
 * mirrors the old stride, update it together with this kernel so hash values agree.
 *
 * @tparam T         type handed to longBytes<T> for every entry read from tempBuffer
 * @param tempBuffer input partial hashes (only read by this kernel)
 * @param tempResult output, one hash per group; needs at least numBlocks entries
 * @param numBlocks  number of groups to produce
 * @param blockSize  number of entries per group
 * @param lastLength number of valid entries in tempBuffer
 */
template <typename T>
static __global__ void internalHash(Nd4jLong* tempBuffer, Nd4jLong* tempResult, Nd4jLong numBlocks, Nd4jLong blockSize, Nd4jLong lastLength) {
    // 64-bit indices: the products below must not wrap for large inputs
    const Nd4jLong step = static_cast<Nd4jLong>(gridDim.x) * blockDim.x;

    for (Nd4jLong b = static_cast<Nd4jLong>(blockIdx.x) * blockDim.x + threadIdx.x; b < numBlocks; b += step) {
        // consecutive groups are blockSize entries apart
        const Nd4jLong offset = b * blockSize;
        auto blockBuffer = tempBuffer + offset;

        Nd4jLong r = 1LL;
        // the second condition trims the last, possibly partial, group
        for (Nd4jLong e = 0; e < blockSize && offset + e < lastLength; e++) {
            auto v = longBytes<T>(blockBuffer[e]);
            r = 31LL * r + v;
        }

        tempResult[b] = r;
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Publishes the final hash: copies one value into *resultBuf.
 *
 * When length <= blockSize the value is read from *tempBufferA, otherwise from
 * *tempResult. Only the thread with threadIdx.x == 0 performs the copy; every other
 * thread returns immediately.
 *
 * @param resultBuf   destination of the final hash
 * @param tempBufferA source used when length <= blockSize
 * @param tempResult  source used when length > blockSize
 * @param length      number of elements of the hashed array
 * @param blockSize   chunk size used by the hashing passes
 */
static __global__ void lastStep(Nd4jLong* resultBuf, Nd4jLong* tempBufferA, Nd4jLong* tempResult, Nd4jLong length, Nd4jLong blockSize) {
    if (threadIdx.x != 0)
        return;

    const Nd4jLong* source = (length <= blockSize) ? tempBufferA : tempResult;
    *resultBuf = *source;
}
|
2019-07-20 07:58:44 +02:00
|
|
|
|
|
|
|
/**
 * Computes a hash of `array` on the device and stores it in the special buffer of `result`.
 *
 * Steps, all enqueued on the CUDA stream of `context`:
 *  1. splitBufferToChuncks hashes the input in chunks of 32 elements into tempA;
 *  2. internalHash repeatedly folds the list of partial hashes (ping-ponging between
 *     tempA and tempB) until a single value is left;
 *  3. lastStep copies that value into result.
 *
 * An empty array is treated as a single empty chunk, so its hash is the initial value 1
 * and no kernel is launched with a zero-sized grid.
 *
 * @tparam T      element type of `array`
 * @param context launch context providing the CUDA stream
 * @param array   array to hash; synced to device before use
 * @param result  array whose special buffer receives the hash as an Nd4jLong
 */
template <typename T>
void hashCode_(LaunchContext *context, NDArray &array, NDArray &result) {
    const int blockSize = 32;          // elements folded into one partial hash
    const int threadsPerBlock = 256;   // each thread handles whole chunks, so any block size is valid

    auto stream = context->getCudaStream();
    array.syncToDevice();

    NDArray::prepareSpecialUse({&result}, {&array});
    auto length = array.lengthOf();

    int numBlocks = static_cast<int>(length / blockSize + ((length % blockSize == 0) ? 0 : 1));
    // empty input: keep one (empty) chunk so the launch grid and temp buffers are never zero-sized
    if (numBlocks < 1)
        numBlocks = 1;

    auto tempA = NDArrayFactory::create<Nd4jLong>('c', {numBlocks}, context);
    auto tempB = NDArrayFactory::create<Nd4jLong>('c', { numBlocks / blockSize + 1}, context);

    auto buffer = reinterpret_cast<T*>(array.specialBuffer());
    auto tempBufferA = reinterpret_cast<Nd4jLong*>(tempA.specialBuffer());
    auto tempBufferB = reinterpret_cast<Nd4jLong*>(tempB.specialBuffer());

    // tempBuffer always points at the most recent partial hashes, tempResult at the scratch
    // buffer the next pass writes to; for small arrays (<= blockSize) tempBuffer stays on A
    auto tempBuffer = tempBufferA;
    auto tempResult = tempBufferB;

    // first pass: one partial hash per 32-element chunk
    int grid = (numBlocks + threadsPerBlock - 1) / threadsPerBlock;
    splitBufferToChuncks<T><<<grid, threadsPerBlock, 1024, *stream>>>(buffer, tempBuffer, numBlocks, blockSize, length);

    // reduction passes: fold the partial hashes until only one value is left
    while (numBlocks > 1) {
        int lastLength = numBlocks;
        numBlocks = lastLength / blockSize + ((lastLength % blockSize == 0) ? 0 : 1);

        grid = (numBlocks + threadsPerBlock - 1) / threadsPerBlock;
        internalHash<Nd4jLong><<<grid, threadsPerBlock, 1024, *stream>>>(tempBuffer, tempResult, numBlocks, blockSize, lastLength);

        // the output of this pass becomes the input of the next one
        auto previousInput = tempBuffer;
        tempBuffer = tempResult;
        tempResult = previousInput;
    }

    // after the swap the final value lives in tempBuffer[0]; tempResult is the stale buffer
    lastStep<<<1, 1, 128, *stream>>>(reinterpret_cast<Nd4jLong*>(result.specialBuffer()), tempBufferA, tempBuffer, length, blockSize);

    NDArray::registerSpecialUse({&result}, {&array});
}
|
|
|
|
|
|
|
|
/**
 * Type-dispatching entry point: selects the hashCode_ instantiation that matches the
 * data type of `array` (among LIBND4J_TYPES) and forwards all arguments to it.
 *
 * @param context launch context forwarded to hashCode_
 * @param array   array to hash
 * @param result  array receiving the hash
 */
void hashCode(LaunchContext *context, NDArray &array, NDArray &result) {
    const auto inputType = array.dataType();

    BUILD_SINGLE_SELECTOR(inputType, hashCode_, (context, array, result), LIBND4J_TYPES);
}
|
|
|
|
|
|
|
|
// Instantiation of hashCode_ over LIBND4J_TYPES, the same type list the selector in hashCode() uses.
BUILD_SINGLE_TEMPLATE(template void hashCode_, (LaunchContext* context, NDArray& array, NDArray& result), LIBND4J_TYPES);
|
2019-07-10 13:32:12 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|