/*******************************************************************************
 * Copyright (c) 2019 Konduit K.K.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

#include "testlayers.h"
// NOTE(review): the original header names were lost in the reviewed copy of
// this file; the project headers below are a best-effort reconstruction.
// Confirm they provide NDArray, LaunchContext, cuda_exception, DataTypeUtils,
// nd4j_atomicMul, _CUDA_G and BUILD_SINGLE_SELECTOR for this build.
#include <NDArray.h>
#include <op_boilerplate.h>
#include <templatemath.h>
#include <array/DataTypeUtils.h>
#include <exceptions/cuda_exception.h>
#include <execution/LaunchContext.h>

#include <cstdint>
#include <stdexcept>
#include <vector>

using namespace nd4j;


class AtomicTests : public testing::Test {
public:
    AtomicTests() {
        //
    }
};

// Multiplies every run of 4 consecutive input elements into one output element:
//     result[e / 4] *= buffer[e]   for each e in [0, length)
// Launch: any 1D grid/block shape. Each thread starts at its global id and
// advances by the total number of launched threads, so all of [0, length) is
// covered regardless of grid size. Uses no shared memory.
// The update goes through nd4j_atomicMul because the 4 inputs that feed one
// output element may be processed by different threads.
template <typename T>
static _CUDA_G void multiplyKernel(void *vbuffer, uint64_t length, void *vresult) {
    auto buffer = reinterpret_cast<T*>(vbuffer);
    auto result = reinterpret_cast<T*>(vresult);

    // 64-bit index and stride: `length` is uint64_t, and a 32-bit
    // blockIdx.x * blockDim.x product (or the stride) could wrap.
    const uint64_t tid  = static_cast<uint64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
    const uint64_t step = static_cast<uint64_t>(gridDim.x) * blockDim.x;

    for (uint64_t e = tid; e < length; e += step) {
        // e / 4 == (e - e % 4) / 4: index of the output this input folds into
        nd4j::math::atomics::nd4j_atomicMul<T>(&result[e / 4], buffer[e]);
    }
}

// Launches multiplyKernel<T> on the default LaunchContext's CUDA stream and
// blocks until it completes.
// Throws nd4j::cuda_exception if the launch or the kernel execution fails.
template <typename T>
static void multiplyLauncher(void *vbuffer, uint64_t length, void *vresult) {
    auto stream = nd4j::LaunchContext::defaultContext()->getCudaStream();

    // The kernel uses no shared memory, so no dynamic shared memory is requested.
    multiplyKernel<T><<<256, 256, 0, *stream>>>(vbuffer, length, vresult);

    // Launch-configuration errors are only reported through cudaGetLastError().
    auto err = cudaGetLastError();
    if (err != cudaSuccess)
        throw nd4j::cuda_exception::build("multiply launch failed", err);

    // Errors raised while the kernel runs surface at this synchronization point.
    err = cudaStreamSynchronize(*stream);
    if (err != cudaSuccess)
        throw nd4j::cuda_exception::build("multiply failed", err);
}

// Dispatches multiplyLauncher on input's data type, operating on the arrays'
// device-side (special) buffers.
// Requires: output has the same data type as input (the kernel reinterprets both
// buffers as the same T) and at least ceil(input.lengthOf() / 4) elements
// (the kernel writes output[e / 4] for every input index e).
// Throws std::invalid_argument when either requirement is violated.
static void multiplyHost(NDArray &input, NDArray &output) {
    if (output.dataType() != input.dataType())
        throw std::invalid_argument("multiplyHost: input and output data types differ");
    if (output.lengthOf() * 4 < input.lengthOf())
        throw std::invalid_argument("multiplyHost: output has fewer than ceil(input.lengthOf() / 4) elements");

    BUILD_SINGLE_SELECTOR(input.dataType(), multiplyLauncher, (input.specialBuffer(), input.lengthOf(), output.specialBuffer()), NUMERIC_TYPES);
}

// For each data type: 100 inputs equal to 2 are folded 4-at-a-time into
// 25 outputs that start at 2, so every output must end at 2 * 2^4 = 32.
TEST_F(AtomicTests, test_multiply) {
    std::vector<nd4j::DataType> dtypes = {nd4j::DataType::FLOAT32, nd4j::DataType::DOUBLE, nd4j::DataType::INT16};

    for (auto t : dtypes) {
        nd4j_printf("Trying data type [%s]\n", DataTypeUtils::asString(t).c_str());
        NDArray input('c', {4, 25}, t);
        NDArray output('c', {input.lengthOf() / 4}, t);
        NDArray exp = output.ulike();

        input.assign(2);
        output.assign(2);
        exp.assign(32);

        multiplyHost(input, output);
        ASSERT_EQ(exp, output);
    }
}