/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #include #include #include #include #include #ifdef __CUDABLAS__ #include #include #endif namespace sd { ExtraArguments::ExtraArguments(std::initializer_list arguments) { _fpArgs = arguments; } ExtraArguments::ExtraArguments(std::initializer_list arguments) { _intArgs = arguments; } ExtraArguments::ExtraArguments(const std::vector &arguments) { _fpArgs = arguments; } ExtraArguments::ExtraArguments(const std::vector &arguments) { _intArgs = arguments; } ExtraArguments::ExtraArguments(const std::vector &arguments) { for (const auto &v:arguments) _intArgs.emplace_back(static_cast(v)); } ExtraArguments::ExtraArguments() { // no-op } ExtraArguments::~ExtraArguments() { for (auto p:_pointers) { #ifdef __CUDABLAS__ cudaFree(p); #else // CPU branch delete[] reinterpret_cast(p); #endif } } template void ExtraArguments::convertAndCopy(Nd4jPointer pointer, Nd4jLong offset) { auto length = this->length(); auto target = reinterpret_cast(pointer); #ifdef __CUDABLAS__ target = new T[length]; #endif if (!_fpArgs.empty()) { for (int e = offset; e < _fpArgs.size(); e++) { target[e] = static_cast(_fpArgs[e]); } } else if (_intArgs.empty()) { for (int e = offset; e < _intArgs.size(); e++) { target[e] = static_cast(_intArgs[e]); } } #ifdef __CUDABLAS__ // TODO: maybe make it asynchronous eventually? cudaMemcpy(pointer, target, length * DataTypeUtils::sizeOf(DataTypeUtils::fromT()), cudaMemcpyHostToDevice); delete[] target; #endif } BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void ExtraArguments::convertAndCopy, (Nd4jPointer pointer, Nd4jLong offset), LIBND4J_TYPES); void* ExtraArguments::allocate(size_t length, size_t elementSize) { #ifdef __CUDABLAS__ Nd4jPointer ptr; auto res = cudaMalloc(reinterpret_cast(&ptr), length * elementSize); if (res != 0) throw std::runtime_error("Can't allocate CUDA memory"); #else // CPU branch auto ptr = new int8_t[length * elementSize]; if (!ptr) throw std::runtime_error("Can't allocate memory"); #endif return ptr; } size_t ExtraArguments::length() { if (!_fpArgs.empty()) return _fpArgs.size(); else if (!_intArgs.empty()) return _intArgs.size(); else return 0; } template void* ExtraArguments::argumentsAsT(Nd4jLong offset) { return argumentsAsT(DataTypeUtils::fromT(), offset); } BUILD_SINGLE_TEMPLATE(template ND4J_EXPORT void *ExtraArguments::argumentsAsT, (Nd4jLong offset), LIBND4J_TYPES); void* ExtraArguments::argumentsAsT(sd::DataType dataType, Nd4jLong offset) { if (_fpArgs.empty() && _intArgs.empty()) return nullptr; // we allocate pointer auto ptr = allocate(length() - offset, DataTypeUtils::sizeOf(dataType)); // fill it with data BUILD_SINGLE_SELECTOR(dataType, convertAndCopy, (ptr, offset), LIBND4J_TYPES); // store it internally for future release _pointers.emplace_back(ptr); return ptr; } }