diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt index 106401b31..0631763c2 100755 --- a/libnd4j/CMakeLists.txt +++ b/libnd4j/CMakeLists.txt @@ -131,6 +131,23 @@ if(NOT SD_CUDA) endif() endif() +#arm-compute entry +if(${HELPERS_armcompute}) + find_package(ARMCOMPUTE REQUIRED) + + if(ARMCOMPUTE_FOUND) + message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}") + set(HAVE_ARMCOMPUTE 1) + # Add preprocessor definition for ARM Compute NEON + add_definitions(-DARMCOMPUTENEON_ENABLED) + #build our library with neon support + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") + include_directories(${ARMCOMPUTE_INCLUDE}) + message("----${ARMCOMPUTE_INCLUDE}---") + endif() + +endif() + # new mkl-dnn entry if (${HELPERS_mkldnn}) diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index fb1dc066e..b6bd1f7c0 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -146,6 +146,10 @@ if (HAVE_MKLDNN) file(GLOB_RECURSE CUSTOMOPS_MKLDNN_SOURCES false ../include/ops/declarable/platform/mkldnn/*.cpp ../include/ops/declarable/platform/mkldnn/mkldnnUtils.h) endif() +if(HAVE_ARMCOMPUTE) + file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/*.h) +endif() + if(SD_CUDA) message("Build cublas") find_package(CUDA) @@ -243,7 +247,7 @@ if(SD_CUDA) ${CUSTOMOPS_HELPERS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${LOOPS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${INDEXING_SOURCES} ${EXCEPTIONS_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES} ${CUSTOMOPS_CUDNN_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} - ${CUSTOMOPS_GENERIC_SOURCES} + ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ) if (WIN32) @@ -351,8 +355,8 @@ elseif(SD_CPU) add_definitions(-D__CPUBLAS__=true) add_library(samediff_obj OBJECT ${LEGACY_SOURCES} ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} - ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} - ${OPS_SOURCES} ${PERF_SOURCES}) + ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} + ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) if(IOS) add_library(${SD_LIBRARY_NAME} STATIC $) else() @@ -378,12 +382,12 @@ elseif(SD_CPU) if (NOT BLAS_LIBRARIES) set(BLAS_LIBRARIES "") endif() - target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES}) + target_link_libraries(${SD_LIBRARY_NAME} ${MKLDNN} ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${BLAS_LIBRARIES} ${CPU_FEATURES}) if ("${SD_ALL_OPS}" AND "${SD_BUILD_MINIFIER}") message(STATUS "Building minifier...") add_executable(minifier ../minifier/minifier.cpp ../minifier/graphopt.cpp) - target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES}) + target_link_libraries(minifier samediff_obj ${MKLDNN_LIBRARIES} ${ARMCOMPUTE_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES}) endif() if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND "${CMAKE_CXX_COMPILER_VERSION}" VERSION_LESS 4.9) diff --git a/libnd4j/cmake/FindARMCOMPUTE.cmake b/libnd4j/cmake/FindARMCOMPUTE.cmake new file mode 100644 index 000000000..ae0e1fbba --- /dev/null +++ b/libnd4j/cmake/FindARMCOMPUTE.cmake @@ -0,0 +1,74 @@ +################################################################################ +# Copyright (c) 2020 Konduit K.K. +# +# This program and the accompanying materials are made available under the +# terms of the Apache License, Version 2.0 which is available at +# https://www.apache.org/licenses/LICENSE-2.0. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. +# +# SPDX-License-Identifier: Apache-2.0 +################################################################################ + + + +### Find ARM COMPUTE LIBRARY STATIC libraries + +SET (COMPUTE_INCLUDE_DIRS + /usr/include + ${ARMCOMPUTE_ROOT} + ${ARMCOMPUTE_ROOT}/include + ${ARMCOMPUTE_ROOT}/applications + ${ARMCOMPUTE_ROOT}/applications/arm_compute +) + + +SET (COMPUTE_LIB_DIRS + /lib + /usr/lib + ${ARMCOMPUTE_ROOT} + ${ARMCOMPUTE_ROOT}/lib + ${ARMCOMPUTE_ROOT}/build +) + +find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h + PATHS ${COMPUTE_INCLUDE_DIRS} + NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH) + +find_path(ARMCOMPUTE_INCLUDE arm_compute/core/CL/ICLKernel.h) + +find_path(HALF_INCLUDE half/half.hpp) +find_path(HALF_INCLUDE half/half.hpp + PATHS ${ARMCOMPUTE_ROOT}/include + NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH) +include_directories(SYSTEM ${HALF_INCLUDE}) + +# Find the Arm Compute libraries if not already specified +if (NOT DEFINED ARMCOMPUTE_LIBRARIES) + + find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static + PATHS ${COMPUTE_LIB_DIRS} + PATH_SUFFIXES "Release" + NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH) + + find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static + PATHS ${COMPUTE_LIB_DIRS} + PATH_SUFFIXES "Release" + NO_DEFAULT_PATH NO_CMAKE_FIND_ROOT_PATH) + # In case it wasn't there, try a default search (will work in cases where + # the library has been installed into a standard location) + find_library(ARMCOMPUTE_LIBRARY NAMES arm_compute-static) + find_library(ARMCOMPUTE_CORE_LIBRARY NAMES arm_compute_core-static) + + set(ARMCOMPUTE_LIBRARIES ${ARMCOMPUTE_LIBRARY} ${ARMCOMPUTE_CORE_LIBRARY} ) +endif() + + +INCLUDE(FindPackageHandleStandardArgs) + +FIND_PACKAGE_HANDLE_STANDARD_ARGS(ARMCOMPUTE REQUIRED_VARS ARMCOMPUTE_INCLUDE ARMCOMPUTE_LIBRARIES) + diff --git a/libnd4j/include/config.h.in b/libnd4j/include/config.h.in index 1e63552d0..c858dd765 100644 --- a/libnd4j/include/config.h.in +++ b/libnd4j/include/config.h.in @@ -3,6 +3,8 @@ #cmakedefine HAVE_MKLDNN +#cmakedefine HAVE_ARMCOMPUTE + #cmakedefine MKLDNN_PATH "@MKLDNN_PATH@" #cmakedefine HAVE_OPENBLAS diff --git a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp index 8f45c696b..7e66d4b11 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/lup.cpp @@ -215,7 +215,9 @@ namespace helpers { auto maxValue = T(0); //sd::math::nd4j_abs(compoundBuffer[xInitialIndex]); auto result = -1; //auto loop = PRAGMA_THREADS_FOR { - auto start = column, stop = rowNum, increment = 1; + auto start = column; + auto stop = rowNum; + auto increment = 1; for (auto rowCounter = start; rowCounter < stop; rowCounter++) { Nd4jLong xPos[] = {rowCounter, column}; auto xIndex = shape::getOffset(compoundShape, xPos, 0); diff --git a/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp new file mode 100644 index 000000000..66b472252 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.cpp @@ -0,0 +1,278 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // Created by Abdelrauf 2020 + + +#include +#include +#include +#include +#include +#include + +#include "armcomputeUtils.h" + + +namespace sd { +namespace ops { +namespace platforms { + + + +Arm_DataType getArmType ( const DataType &dType){ + Arm_DataType ret; + switch (dType){ + case HALF : + ret = Arm_DataType::F16; + break; + case FLOAT32 : + ret = Arm_DataType::F32; + break; + case DOUBLE : + ret = Arm_DataType::F64; + break; + case INT8 : + ret = Arm_DataType::S8; + break; + case INT16 : + ret = Arm_DataType::S16; + break; + case INT32 : + ret = Arm_DataType::S32; + break; + case INT64 : + ret = Arm_DataType::S64; + break; + case UINT8 : + ret = Arm_DataType::U8; + break; + case UINT16 : + ret = Arm_DataType::U16; + break; + case UINT32 : + ret = Arm_DataType::U32; + break; + case UINT64 : + ret = Arm_DataType::U64; + break; + case BFLOAT16 : + ret = Arm_DataType::BFLOAT16; + break; + default: + ret = Arm_DataType::UNKNOWN; + }; + + return ret; +} +bool isArmcomputeFriendly(const NDArray& arr) { + auto dType = getArmType(arr.dataType()); + int rank = (int)(arr.rankOf()); + return dType != Arm_DataType::UNKNOWN && + rank<=arm_compute::MAX_DIMS && + arr.ordering() == 'c' && + arr.ews()==1 && + shape::strideDescendingCAscendingF(arr.shapeInfo()) == true; +} + +Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases,sd::DataType ndArrayType, arm_compute::DataLayout layout) { + constexpr int numChannels = 1; + auto dType = getArmType(ndArrayType); + + Arm_TensorShape shape; + shape.set_num_dimensions(rank); + for (int i = 0, j = rank - 1; i < rank; i++, j--) { + shape[i] = static_cast(bases[j]); + } + // fill the rest unused with 1 + for (int i = rank; i < arm_compute::MAX_DIMS; i++) { + shape[i] = 1; + } + + return Arm_TensorInfo(shape, numChannels, dType, layout); +} + +Arm_TensorInfo getArmTensorInfo(const NDArray& arr, + arm_compute::DataLayout layout) { + auto dType = getArmType(arr.dataType()); + + // + constexpr int numChannels = 1; + int rank = (int)(arr.rankOf()); + auto bases = arr.shapeOf(); + auto arrStrides = arr.stridesOf(); + + // https://arm-software.github.io/ComputeLibrary/v20.05/_dimensions_8h_source.xhtml + // note: underhood it is stored as std::array _id; + // TensorShape is derived from Dimensions + // as well as Strides : public Dimensions + Arm_TensorShape shape; + Arm_Strides strides; + shape.set_num_dimensions(rank); + strides.set_num_dimensions(rank); + size_t element_size = arm_compute::data_size_from_type(dType); + for (int i = 0, j = rank - 1; i < rank; i++, j--) { + shape[i] = static_cast(bases[j]); + strides[i] = static_cast(arrStrides[j]) * element_size; + } + // fill the rest unused with 1 + for (int i = rank; i < arm_compute::MAX_DIMS; i++) { + shape[i] = 1; + } + size_t total_size; + size_t size_ind = rank - 1; + total_size = shape[size_ind] * strides[size_ind]; + + Arm_TensorInfo info; + info.init(shape, numChannels, dType, strides, 0, total_size); + info.set_data_layout(layout); + + return info; +} + +Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout) { + // - Ownership of the backing memory is not transferred to the tensor itself. + // - The tensor mustn't be memory managed. + // - Padding requirements should be accounted by the client code. + // In other words, if padding is required by the tensor after the function + // configuration step, then the imported backing memory should account for it. + // Padding can be checked through the TensorInfo::padding() interface. + + // Import existing pointer as backing memory + auto info = getArmTensorInfo(arr, layout); + Arm_Tensor tensor; + tensor.allocator()->init(info); + void* buff = (void*)arr.buffer(); + tensor.allocator()->import_memory(buff); + return tensor; +} + +void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output) { + //only for C order + //only for C order + if (output.ordering() != 'c') return; + auto shapeInfo = output.shapeInfo(); + auto bases = &(shapeInfo[1]); + Nd4jLong rank = shapeInfo[0]; + auto strides = output.stridesOf(); + int width = bases[rank - 1]; + uint8_t* outputBuffer = (uint8_t*)output.buffer(); + size_t offset = 0; + arm_compute::Window window; + arm_compute::Iterator tensor_it(&inTensor, window); + + int element_size = inTensor.info()->element_size(); + window.use_tensor_dimensions(inTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY); + +// if (output.ews() == 1) { + auto copySize = width * element_size; + auto dest = outputBuffer; + arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id) + { + auto src = tensor_it.ptr(); + memcpy(dest, src, copySize); + dest += copySize; + }, + tensor_it); + // } + // else { + // Nd4jLong coords[MAX_RANK] = {}; + // if(strides[rank-1]!=1){ + // throw std::runtime_error( "not implemented for subarrays whose last stride is not 1"); + // //TODO: implement to work with all subarrays properly + // } + // arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id) + // { + // auto src = tensor_it.ptr(); + // auto dest = outputBuffer + offset * element_size; + // memcpy(dest, src, width * element_size); + // offset = sd::inc_coords(bases, strides, coords, offset, rank, 1); + // }, + // tensor_it); + // } +} + +void copyToTensor(const NDArray& input, Arm_Tensor& outTensor) { + //only for C order + if (input.ordering() != 'c') return; + auto shapeInfo = input.shapeInfo(); + auto bases = &(shapeInfo[1]); + Nd4jLong rank = shapeInfo[0]; + auto strides = input.stridesOf(); + uint8_t *inputBuffer = (uint8_t*)input.buffer(); + int width = bases[rank - 1]; + size_t offset = 0; + arm_compute::Window window; + arm_compute::Iterator tensor_it(&outTensor, window); + int element_size = outTensor.info()->element_size(); + + window.use_tensor_dimensions(outTensor.info()->tensor_shape(), /* first_dimension =*/arm_compute::Window::DimY); + +// if (input.ews() == 1) { + + auto copySize = width * element_size; + auto src = inputBuffer; + arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id) + { + auto dest = tensor_it.ptr(); + memcpy(dest,src, copySize); + src += copySize; + }, + tensor_it); +// } +// else { +// Nd4jLong coords[MAX_RANK] = {}; +// if(strides[rank-1]!=1){ +// throw std::runtime_error( "not implemented for subarrays whose last stride is not 1"); +// //TODO: implement to work with all subarrays properly +// } +// arm_compute::execute_window_loop(window, [&](const arm_compute::Coordinates& id) +// { +// auto dest = tensor_it.ptr(); +// auto src = inputBuffer + offset * element_size; +// offset = sd::inc_coords(bases, strides, coords, offset, rank, 1); +// }, +// tensor_it); +// } +} + + +// armcompute should be built with debug option +void print_tensor(Arm_ITensor& tensor, const char* msg) { + auto info = tensor.info(); + auto padding = info->padding(); + std::cout << msg << "\ntotal: " << info->total_size() << "\n"; + + for (int i = 0; i < arm_compute::MAX_DIMS; i++) { + std::cout << info->dimension(i) << ","; + } + std::cout << std::endl; + for (int i = 0; i < arm_compute::MAX_DIMS; i++) { + std::cout << info->strides_in_bytes()[i] << ","; + } + std::cout << "\npadding: l " << padding.left << ", r " << padding.right + << ", t " << padding.top << ", b " << padding.bottom << std::endl; + +#ifdef ARM_COMPUTE_ASSERTS_ENABLED + //note it did not print correctly fro NHWC + std::cout << msg << ":\n"; + tensor.print(std::cout); + std::cout << std::endl; +#endif +} + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h new file mode 100644 index 000000000..72a4e6e89 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/armcompute/armcomputeUtils.h @@ -0,0 +1,133 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + +#ifndef DEV_TESTSARMCOMPUTEUTILS_H +#define DEV_TESTSARMCOMPUTEUTILS_H + + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace samediff; + + +namespace sd { + namespace ops { + namespace platforms { + + using Arm_DataType = arm_compute::DataType; + using Arm_Tensor = arm_compute::Tensor; + using Arm_ITensor = arm_compute::ITensor; + using Arm_TensorInfo = arm_compute::TensorInfo; + using Arm_TensorShape = arm_compute::TensorShape; + using Arm_Strides = arm_compute::Strides; + /** + * Here we actually declare our platform helpers + */ + + + DECLARE_PLATFORM(maxpool2d, ENGINE_CPU); + + DECLARE_PLATFORM(avgpool2d, ENGINE_CPU); + + //utils + Arm_DataType getArmType(const sd::DataType& dType); + + Arm_TensorInfo getArmTensorInfo(int rank, Nd4jLong* bases, sd::DataType ndArrayType, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN); + + Arm_TensorInfo getArmTensorInfo(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN); + + Arm_Tensor getArmTensor(const NDArray& arr, arm_compute::DataLayout layout = arm_compute::DataLayout::UNKNOWN); + + void copyFromTensor(const Arm_Tensor& inTensor, NDArray& output); + void copyToTensor(const NDArray& input, Arm_Tensor& outTensor); + void print_tensor(Arm_ITensor& tensor, const char* msg); + bool isArmcomputeFriendly(const NDArray& arr); + + + template + class ArmFunction { + public: + + template + void configure(NDArray *input , NDArray *output, arm_compute::DataLayout layout, Args&& ...args) { + + auto inInfo = getArmTensorInfo(*input, layout); + auto outInfo = getArmTensorInfo(*output, layout); + in.allocator()->init(inInfo); + out.allocator()->init(outInfo); + armFunction.configure(&in,&out,std::forward(args) ...); + if (in.info()->has_padding()) { + //allocate and copy + in.allocator()->allocate(); + //copy + copyToTensor(*input, in); + + } + else { + //import buffer + void* buff = input->buffer(); + in.allocator()->import_memory(buff); + } + if (out.info()->has_padding()) { + //store pointer to our array to copy after run + out.allocator()->allocate(); + outNd = output; + } + else { + //import + void* buff = output->buffer(); + out.allocator()->import_memory(buff); + } + + } + + void run() { + armFunction.run(); + if (outNd) { + copyFromTensor(out, *outNd); + } + } + + private: + Arm_Tensor in; + Arm_Tensor out; + NDArray *outNd=nullptr; + F armFunction{}; + }; + } + } +} + + + +#endif //DEV_TESTSARMCOMPUTEUTILS_H diff --git a/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp b/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp new file mode 100644 index 000000000..d8413104d --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/armcompute/avgpooling2d.cpp @@ -0,0 +1,106 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // Created by Abdelrauf (rauf@konduit.ai) 2020 + +#include +#include +#include +#include + + +#include "armcomputeUtils.h" + + +namespace sd { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(avgpool2d, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + + const auto kH = INT_ARG(0); + const auto kW = INT_ARG(1); + const auto sH = INT_ARG(2); + const auto sW = INT_ARG(3); + auto pH = INT_ARG(4); + auto pW = INT_ARG(5); + const auto dH = INT_ARG(6); + const auto dW = INT_ARG(7); + const auto paddingMode = INT_ARG(8); + const auto extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 0-NCHW, 1-NHWC + + REQUIRE_TRUE(input->rankOf() == 4, 0, "AVGPOOL2D ARMCOMPUTE op: input should have rank of 4, but got %i instead", input->rankOf()); + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "AVGPOOL2D ARMCOMPUTE op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + bool exclude_padding= (extraParam0 == 0) ? true : false; + + auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC; + + // Calculate individual paddings + unsigned int pad_left, pad_top, pad_right, pad_bottom; + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + if(paddingMode){ + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + } + pad_left = pW; + pad_top = pH; + pad_right = (oW - 1) * sW - iW + kW - pW ; + pad_bottom = (oH - 1) * sH - iH + kH - pH ; + +#if 0 + nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n" , kH , kW , sH , sW , pH + , pW , dH , dW , paddingMode,isNCHW?1:0 ,exclude_padding?1:0); +#endif + auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR); + auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::AVG, arm_compute::Size2D(kW, kH), dataLayout, poolPad, exclude_padding); + ArmFunction pool; + pool.configure(input,output, dataLayout, poolInfo); + + pool.run(); // run function + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(avgpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + // Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 + auto dTypeInput = getArmType(input->dataType()); + auto dTypeOutput = getArmType(output->dataType()); + bool is_supported = dH==1 && dW==1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output) + && (dTypeInput ==Arm_DataType::F32) + && (dTypeOutput ==Arm_DataType::F32); + return is_supported; +} + + + +} +} +} diff --git a/libnd4j/include/ops/declarable/platform/armcompute/maxpooling2d.cpp b/libnd4j/include/ops/declarable/platform/armcompute/maxpooling2d.cpp new file mode 100644 index 000000000..cd6779628 --- /dev/null +++ b/libnd4j/include/ops/declarable/platform/armcompute/maxpooling2d.cpp @@ -0,0 +1,106 @@ +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // Created by Abdelrauf 2020 + + +#include +#include +#include +#include + + +#include "armcomputeUtils.h" + + +namespace sd { +namespace ops { +namespace platforms { + + +////////////////////////////////////////////////////////////////////////// +PLATFORM_IMPL(maxpool2d, ENGINE_CPU) { + + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + REQUIRE_TRUE(input->rankOf() == 4, 0, "MAXPOOL2D ARMCOMPUTE OP: input array should have rank of 4, but got %i instead", input->rankOf()); + + // 0,1 - kernel Height/Width; 2,3 - stride Height/Width; 4,5 - pad Height/Width; 6,7 - dilation Height/Width; 8 - same mode; + const int kH = INT_ARG(0); + const int kW = INT_ARG(1); + const int sH = INT_ARG(2); + const int sW = INT_ARG(3); + int pH = INT_ARG(4); + int pW = INT_ARG(5); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + const int paddingMode = INT_ARG(8); + // const int extraParam0 = INT_ARG(9); + const int isNCHW = block.getIArguments()->size() > 10 ? !INT_ARG(10) : 1; // INT_ARG(10): 1-NHWC, 0-NCHW + + REQUIRE_TRUE(dH != 0 && dW != 0, 0, "MAXPOOL2D MKLDNN op: dilation must not be zero, but got instead {%i, %i}", dH, dW); + + auto dataLayout = isNCHW ? arm_compute::DataLayout::NCHW : arm_compute::DataLayout::NHWC; + + // Calculate individual paddings + unsigned int pad_left, pad_top, pad_right, pad_bottom; + int bS, iC, iH, iW, oC, oH, oW; // batch size, input channels, input height/width, output channels, output height/width; + int indIOioC, indIiH, indWoC, indWiC, indWkH, indOoH; // corresponding indexes + ConvolutionUtils::getSizesAndIndexesConv2d(isNCHW, 0, *input, *output, bS, iC, iH, iW, oC, oH, oW, indIOioC, indIiH, indWiC, indWoC, indWkH, indOoH); + + if(paddingMode){ + ConvolutionUtils::calcPadding2D(pH, pW, oH, oW, iH, iW, kH, kW, sH, sW, dH, dW); + } + pad_left = pW; + pad_top = pH; + pad_right = (oW - 1) * sW - iW + kW - pW ; + pad_bottom = (oH - 1) * sH - iH + kH - pH ; +#if 0 + nd4j_printf("avgpool kH = %d, kW = %d, sH = %d, sW = %d , pH = %d , pW = %d, dH = %d, dW = %d, paddingMode = %d , isNCHW %d exclude pad %d \n" , kH , kW , sH , sW , pH + , pW , dH , dW , paddingMode,isNCHW?1:0 ,exclude_padding?1:0); +#endif + + auto poolPad = arm_compute::PadStrideInfo(sW, sH, pad_left,pad_right, pad_top, pad_bottom, arm_compute::DimensionRoundingType::FLOOR); + auto poolInfo = arm_compute::PoolingLayerInfo(arm_compute::PoolingType::MAX, arm_compute::Size2D(kW, kH), dataLayout, poolPad); + ArmFunction pool; + + pool.configure(input,output, dataLayout, poolInfo); + + pool.run(); // run function + + return Status::OK(); +} + +////////////////////////////////////////////////////////////////////////// +PLATFORM_CHECK(maxpool2d, ENGINE_CPU) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + const int dH = INT_ARG(6); + const int dW = INT_ARG(7); + // Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32 + auto dTypeInput = getArmType(input->dataType()); + auto dTypeOutput = getArmType(output->dataType()); + bool is_supported = dH==1 && dW==1 && isArmcomputeFriendly(*input) && isArmcomputeFriendly(*output) + && (dTypeInput ==Arm_DataType::F32) + && (dTypeOutput ==Arm_DataType::F32); + return is_supported; +} + + + +} +} +} diff --git a/libnd4j/pi_build.sh b/libnd4j/pi_build.sh new file mode 100755 index 000000000..f96c3f1f1 --- /dev/null +++ b/libnd4j/pi_build.sh @@ -0,0 +1,185 @@ +#!/bin/bash +TARGET=armv7-a +BLAS_TARGET_NAME=ARMV7 +ARMCOMPUTE_TARGET=armv7a +#BASE_DIR=${HOME}/pi +#https://stackoverflow.com/questions/59895/how-to-get-the-source-directory-of-a-bash-script-from-within-the-script-itself +SOURCE="${BASH_SOURCE[0]}" +ARMCOMPUTE_DEBUG=1 +LIBND4J_BUILD_MODE=Release +while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink + DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )" + SOURCE="$(readlink "$SOURCE")" + [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located +done +BASE_DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )" +CMAKE=cmake #/snap/bin/cmake + +mkdir -p ${BASE_DIR}/helper_bin/ + +CROSS_COMPILER_URL=https://sourceforge.net/projects/raspberry-pi-cross-compilers/files/Raspberry%20Pi%20GCC%20Cross-Compiler%20Toolchains/Buster/GCC%208.3.0/Raspberry%20Pi%203A%2B%2C%203B%2B%2C%204/cross-gcc-8.3.0-pi_3%2B.tar.gz/download +CROSS_COMPILER_DIR=${BASE_DIR}/helper_bin/cross_compiler + +SCONS_LOCAL_URL=http://prdownloads.sourceforge.net/scons/scons-local-3.1.1.tar.gz +SCONS_LOCAL_DIR=${BASE_DIR}/helper_bin/scons_local + +THIRD_PARTY=${BASE_DIR}/third_party_libs + +ARMCOMPUTE_GIT_URL=https://github.com/ARM-software/ComputeLibrary.git +ARMCOMPUTE_TAG=v20.05 +ARMCOMPUTE_DIR=${THIRD_PARTY}/arm_compute_dir + +OPENBLAS_GIT_URL="https://github.com/xianyi/OpenBLAS.git" +OPENBLAS_DIR=${THIRD_PARTY}/OpenBLAS + + +LIBND4J_SRC_DIR=${BASE_DIR} + +LIBND4J_BUILD_DIR=${BASE_DIR}/build_pi + +#for some downloads +XRTACT_STRIP="--strip-components=1" + +HAS_ARMCOMPUTE=1 +mkdir -p ${BASE_DIR} +mkdir -p ${THIRD_PARTY} + +#change directory to base +cd $BASE_DIR + +function message { + echo "BUILDER:::: ${@}" +} + + +function check_requirements { + for i in "${@}" + do + if [ ! -e "$i" ]; then + message "missing: ${i}" + exit -2 + fi + done +} + +function download_extract { + #$1 is url #2 is dir $3 is extract argument + if [ ! -f ${2}_file ]; then + message "download" + wget --quiet --show-progress -O ${2}_file ${1} + fi + + message "extract" + #extract + mkdir -p ${2} + command="tar -xzf ${2}_file --directory=${2} ${3} " + message $command + $command + + check_requirements "${2}" +} + +function git_check { + #$1 is url #$2 is dir #$3 is tag or branch if optional + command="git clone --quiet ${1} ${2}" + message "$command" + $command + if [ -n "$3" ]; then + cd ${2} + command="git checkout ${3}" + message "$command" + $command + cd ${BASE_DIR} + fi + check_requirements "${2}" +} + + +if [ ! -d ${CROSS_COMPILER_DIR} ]; then + #out file + message "download CROSS_COMPILER" + download_extract ${CROSS_COMPILER_URL} ${CROSS_COMPILER_DIR} ${XRTACT_STRIP} +fi + +#useful exports +export PI_FOLDER=${CROSS_COMPILER_DIR} +export RPI_BIN=${PI_FOLDER}/bin/arm-linux-gnueabihf +export PI_SYS_ROOT=${PI_FOLDER}/arm-linux-gnueabihf/libc +export LD_LIBRARY_PATH=${PI_FOLDER}/lib:$LD_LIBRARY_PATH +export CC=${RPI_BIN}-gcc +export FC=${RPI_BIN}-gfortran +export CXX=${RPI_BIN}-g++ +export CPP=${RPI_BIN}-cpp +export RANLIB=${RPI_BIN}-gcc-ranlib +export LD="${RPI_BIN}-ld" +export AR="${RPI_BIN}-ar" + + +#lets build OpenBlas +if [ ! -d "${OPENBLAS_DIR}" ]; then + message "download OpenBLAS" + git_check "${OPENBLAS_GIT_URL}" "${OPENBLAS_DIR}" +fi + +if [ ! -f "${THIRD_PARTY}/lib/libopenblas.so" ]; then + message "build and install OpenBLAS" + cd ${OPENBLAS_DIR} + + command="make TARGET=${BLAS_TARGET_NAME} HOSTCC=gcc CC=${CC} USE_THREAD=0 NOFORTRAN=1 CFLAGS=--sysroot=${PI_SYS_ROOT} LDFLAGS=\"-L${PI_SYS_ROOT}/../lib/ -lm\" &>/dev/null" + message $command + eval $command + message "install it" + command="make PREFIX=${THIRD_PARTY} install" + message $command + $command + cd $BASE_DIR + +fi +check_requirements ${THIRD_PARTY}/lib/libopenblas.so + + + +if [ ! -d ${SCONS_LOCAL_DIR} ]; then + #out file + message "download Scons local" + download_extract ${SCONS_LOCAL_URL} ${SCONS_LOCAL_DIR} +fi +check_requirements ${SCONS_LOCAL_DIR}/scons.py + + +if [ ! -d "${ARMCOMPUTE_DIR}" ]; then + message "download ArmCompute Source" + git_check ${ARMCOMPUTE_GIT_URL} "${ARMCOMPUTE_DIR}" "tags/${ARMCOMPUTE_TAG}" +fi + +#build armcompute +if [ ! -f "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" ]; then +message "build arm compute" +cd ${ARMCOMPUTE_DIR} +command="CC=gcc CXX=g++ python3 ${SCONS_LOCAL_DIR}/scons.py Werror=1 -j$(nproc) toolchain_prefix=${RPI_BIN}- debug=${ARMCOMPUTE_DEBUG} neon=1 opencl=0 extra_cxx_flags=-fPIC os=linux build=cross_compile arch=${ARMCOMPUTE_TARGET} &>/dev/null" +message $command +eval $command +cd ${BASE_DIR} +fi +check_requirements "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" "${ARMCOMPUTE_DIR}/build/libarm_compute_core-static.a" + + + +message "build cmake for LIBND4J. output: ${LIBND4J_BUILD_DIR}" + +TOOLCHAIN=${LIBND4J_SRC_DIR}/cmake/rpi.cmake +cmake_cmd="${CMAKE} -G \"Unix Makefiles\" -B${LIBND4J_BUILD_DIR} -S${LIBND4J_SRC_DIR} -DCMAKE_BUILD_TYPE=${LIBND4J_BUILD_MODE} -DCMAKE_TOOLCHAIN_FILE=${TOOLCHAIN} -DCMAKE_VERBOSE_MAKEFILE:BOOL=ON -DSD_ALL_OPS=true -DSD_CPU=true -DSD_LIBRARY_NAME=nd4jcpu -DSD_BUILD_TESTS=ON -DSD_ARM_BUILD=true -DOPENBLAS_PATH=${THIRD_PARTY} -DSD_ARCH=${TARGET} -DARMCOMPUTE_ROOT=${ARMCOMPUTE_DIR} -DHELPERS_armcompute=${HAS_ARMCOMPUTE}" +message $cmake_cmd +eval $cmake_cmd + +#build +message "lets build" + +cd ${LIBND4J_BUILD_DIR} +make -j $(nproc) + + + + + + diff --git a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt index 563bf58f6..9478f6fe2 100644 --- a/libnd4j/tests_cpu/layers_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/layers_tests/CMakeLists.txt @@ -52,14 +52,19 @@ elseif(WIN32) set(CMAKE_CXX_FLAGS " -fPIC") endif() else() - set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") set(CMAKE_CXX_FLAGS " -fPIC") + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") + IF(${SD_ARCH} MATCHES "arm*") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=${SD_ARCH}") + else() + set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3") + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64*") set(CMAKE_CXX_FLAGS " ${CMAKE_CXX_FLAGS} -mcpu=native") else() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -march=native -mtune=native") endif() - + endif() if (SD_CPU AND SD_SANITIZE) set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -fsanitize=address") else() @@ -130,7 +135,7 @@ if (SD_CPU) endif() add_executable(runtests ${TEST_SOURCES}) - target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} gtest gtest_main) + target_link_libraries(runtests samediff_obj ${MKLDNN_LIBRARIES} ${OPENBLAS_LIBRARIES} ${MKLDNN} ${BLAS_LIBRARIES} ${CPU_FEATURES} ${ARMCOMPUTE_LIBRARIES} gtest gtest_main) elseif(SD_CUDA) add_executable(runtests ${TEST_SOURCES}) diff --git a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp index 169c51124..39277cd87 100644 --- a/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp +++ b/libnd4j/tests_cpu/layers_tests/ConvolutionTests2.cpp @@ -1113,7 +1113,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_6) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1132,7 +1135,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_7) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); @@ -1151,7 +1157,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_8) { ASSERT_EQ(ND4J_STATUS_OK, result.status()); auto z = result.at(0); - +#if 0 + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif ASSERT_TRUE(exp.isSameShape(z)); ASSERT_TRUE(exp.equalsTo(z)); } @@ -1204,7 +1213,10 @@ TYPED_TEST(TypedConvolutionTests2, maxpool2d_10) { auto* output = results.at(0); ASSERT_EQ(Status::OK(), results.status()); - +#if 0 + expOutput.printIndexedBuffer("expOutput"); + output->printIndexedBuffer("output"); +#endif ASSERT_TRUE(expOutput.isSameShape(output)); ASSERT_TRUE(expOutput.equalsTo(output)); } diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index d3d1deed8..beccc1aae 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -244,7 +244,8 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode) { #ifdef _RELEASE TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { // [2,1,135079944,1,1,8192,1,99] - auto initial = NDArrayFactory::create('c', {1, 135079944}); + constexpr int sizeX= 10*1000*1000; + auto initial = NDArrayFactory::create('c', {1, sizeX}); initial = 1.0f; auto exp = initial.dup(); auto neg = initial.like(); @@ -254,7 +255,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { auto enc_result = enc.evaluate({&initial}, {0.5f}); auto encoded = enc_result.at(1); - ASSERT_EQ(135079944 + 4, encoded->lengthOf()); + ASSERT_EQ(sizeX + 4, encoded->lengthOf()); ASSERT_NE(exp, initial); /* for (int e = 0; e < initial.lengthOf(); e++) { diff --git a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp b/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp deleted file mode 100644 index 8481dfde5..000000000 --- a/libnd4j/tests_cpu/layers_tests/SessionLocalTests.cpp +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -// -// @author raver119@gmail.com -// - -#ifndef LIBND4J_SESSIONLOCALTESTS_H -#define LIBND4J_SESSIONLOCALTESTS_H - -#include "testlayers.h" -#include -#include - -using namespace sd::graph; - -class SessionLocalTests : public testing::Test { -public: - -}; - -TEST_F(SessionLocalTests, BasicTests_1) { - VariableSpace variableSpace; - SessionLocalStorage storage(&variableSpace, nullptr); - - if (omp_get_max_threads() <= 1) - return; - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.startSession(); - } - - ASSERT_EQ(4, storage.numberOfSessions()); - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.endSession(); - } - - ASSERT_EQ(0, storage.numberOfSessions()); -} - - -TEST_F(SessionLocalTests, BasicTests_2) { - VariableSpace variableSpace; - SessionLocalStorage storage(&variableSpace, nullptr); - - if (omp_get_max_threads() <= 1) - return; - - auto alpha = sd::NDArrayFactory::create_('c',{5,5}); - alpha->assign(0.0); - - variableSpace.putVariable(-1, alpha); - - PRAGMA_OMP_PARALLEL_FOR_THREADS(4) - for (int e = 0; e < 4; e++) { - storage.startSession(); - - auto varSpace = storage.localVariableSpace(); - - auto arr = varSpace->getVariable(-1)->getNDArray(); - arr->applyScalar(sd::scalar::Add, (float) e+1, *arr); - } - - float lastValue = 0.0f; - for (int e = 1; e <= 4; e++) { - auto varSpace = storage.localVariableSpace((Nd4jLong) e); - - auto arr = varSpace->getVariable(-1)->getNDArray(); - - //nd4j_printf("Last value: %f; Current value: %f\n", lastValue, arr->e(0)); - - ASSERT_NE(lastValue, arr->e(0)); - lastValue = arr->e(0); - } -} - -#endif //LIBND4J_SESSIONLOCALTESTS_H diff --git a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt index 7e01e2847..bbd632d27 100644 --- a/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt +++ b/libnd4j/tests_cpu/libnd4j_tests/CMakeLists.txt @@ -45,6 +45,21 @@ if ("${BUILD_MKLDNN}") set(MKLDNN dnnl) endif() +if (${HELPERS_armcompute}) + find_package(ARMCOMPUTE REQUIRED) + + if(ARMCOMPUTE_FOUND) + message("Found ARMCOMPUTE: ${ARMCOMPUTE_LIBRARIES}") + set(HAVE_ARMCOMPUTE 1) + # Add preprocessor definition for ARM Compute NEON + add_definitions(-DARMCOMPUTENEON_ENABLED) + #build our library with neon support + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mfpu=neon") + include_directories(${ARMCOMPUTE_INCLUDE}) + endif() + +endif() + # Download and unpack flatbuffers at configure time configure_file(../../CMakeLists.txt.in flatbuffers-download/CMakeLists.txt) execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . @@ -217,6 +232,10 @@ if ("${BUILD_MKLDNN}") file(GLOB_RECURSE CUSTOMOPS_PLATFORM_SOURCES false ../../include/ops/declarable/platform/mkldnn/*.cpp) endif() +if(HAVE_ARMCOMPUTE) + file(GLOB_RECURSE CUSTOMOPS_ARMCOMPUTE_SOURCES false ../include/ops/declarable/platform/armcompute/*.cpp ../include/ops/declarable/platform/armcompute/armcomputeUtils.h) +endif() + message("CPU backend") add_definitions(-D__CPUBLAS__=true) @@ -276,8 +295,9 @@ endforeach(TMP_PATH) add_executable(runtests ${LOOPS_SOURCES} ${LEGACY_SOURCES} ${EXEC_SOURCES} ${HELPERS_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} - ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} + ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_PLATFORM_SOURCES} + ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${TEST_SOURCES} ${PERF_SOURCES}) -target_link_libraries(runtests gtest ${MKLDNN} gtest_main ${BLAS_LIBRARIES}) +target_link_libraries(runtests gtest ${MKLDNN} ${ARMCOMPUTE_LIBRARIES} gtest_main ${BLAS_LIBRARIES})