/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// @author raver119@gmail.com
//

// NOTE(review): the targets of the eight original #include directives were lost
// when this file was extracted (only the bare `#include` tokens survived). The
// project headers below are a best-effort reconstruction from the names this
// file uses -- confirm against version control before merging.
#include <exceptions/cuda_exception.h>
#include <ConstantHelper.h>
#include <DataTypeUtils.h>
#include <execution/LaunchContext.h>
#include <specials.h>
#include <logger.h>
#include <cuda_runtime.h>
#include <cuda.h>

// Required by the std::lock_guard used in replicatePointer().
#include <mutex>

// Size, in bytes, of the deviceConstantMemory array declared below (48 * 1024).
#define CONSTANT_LIMIT 49152

// Constant-memory array that replicatePointer() copies host data into.
__constant__ char deviceConstantMemory[CONSTANT_LIMIT];

namespace nd4j {

    /**
     * Looks up the address of deviceConstantMemory via cudaGetSymbolAddress.
     *
     * @return address reported by cudaGetSymbolAddress
     * @throws cuda_exception if cudaGetSymbolAddress returns a non-zero code
     */
    static void* getConstantSpace() {
        Nd4jPointer dConstAddr;
        // NOTE(review): the cast's template argument was stripped in extraction;
        // void** is what cudaGetSymbolAddress takes as its first parameter.
        auto dZ = cudaGetSymbolAddress(reinterpret_cast<void **>(&dConstAddr), deviceConstantMemory);

        if (dZ != 0)
            throw cuda_exception::build("cudaGetSymbolAddress(...) failed", dZ);

        return dConstAddr;
    }

    /**
     * @return device index reported by cudaGetDevice
     * @throws cuda_exception if cudaGetDevice returns a non-zero code
     */
    int ConstantHelper::getCurrentDevice() {
        int dev = 0;
        auto res = cudaGetDevice(&dev);

        if (res != 0)
            throw cuda_exception::build("cudaGetDevice failed", res);

        return dev;
    }

    /**
     * @return device count reported by cudaGetDeviceCount
     * @throws cuda_exception if cudaGetDeviceCount returns a non-zero code
     */
    int ConstantHelper::getNumberOfDevices() {
        int dev = 0;
        auto res = cudaGetDeviceCount(&dev);

        if (res != 0)
            throw cuda_exception::build("cudaGetDeviceCount failed", res);

        return dev;
    }

    /**
     * Sizes the per-device tables to the number of devices and fills one slot per
     * device: the constant-space address, a zero offset, an empty cache and a zero
     * byte counter. Each device is made current in turn to query its address, and
     * the device that was current on entry is made current again afterwards.
     *
     * @throws cuda_exception if any of the CUDA calls involved fails
     */
    ConstantHelper::ConstantHelper() {
        auto initialDevice = getCurrentDevice();
        auto numDevices = getNumberOfDevices();

        _devicePointers.resize(numDevices);
        _deviceOffsets.resize(numDevices);
        _cache.resize(numDevices);
        _counters.resize(numDevices);

        // filling all pointers
        for (int e = 0; e < numDevices; e++) {
            auto res = cudaSetDevice(e);
            if (res != 0)
                throw cuda_exception::build("cudaSetDevice failed", res);

            _devicePointers[e] = getConstantSpace();
            _deviceOffsets[e] = 0;
            // The original assigned a freshly constructed empty map here;
            // clearing the slot yields the same empty cache.
            _cache[e].clear();
            _counters[e] = 0L;
        }

        // switch back to the device that was current when we started
        auto res = cudaSetDevice(initialDevice);
        if (res != 0)
            throw cuda_exception::build("Final cudaSetDevice failed", res);
    }

    /**
     * Returns the shared instance, constructing it on the first call.
     *
     * NOTE(review): the null check and the assignment are not synchronized here;
     * confirm that the first call cannot race between threads.
     */
    ConstantHelper* ConstantHelper::getInstance() {
        if (_INSTANCE == nullptr)
            _INSTANCE = new nd4j::ConstantHelper();

        return _INSTANCE;
    }

    /**
     * Copies numBytes from host pointer src to the current device and returns the
     * device-side address of the copy.
     *
     * While the data still fits below CONSTANT_LIMIT at the device's current
     * offset, it is copied into deviceConstantMemory and the offset advances by
     * numBytes rounded up to a multiple of 8. Otherwise a buffer is obtained via
     * ALLOCATE_SPECIAL and filled with cudaMemcpy.
     *
     * @param src host pointer to the bytes to copy
     * @param numBytes number of bytes to copy
     * @param workspace forwarded to ALLOCATE_SPECIAL on the fallback path
     * @return device pointer to the copied bytes
     * @throws cuda_exception if a CUDA call fails
     */
    void* ConstantHelper::replicatePointer(void *src, size_t numBytes, memory::Workspace *workspace) {
        // Scoped lock: the original paired lock()/unlock() by hand and left the
        // mutex locked forever whenever one of the throws below fired.
        std::lock_guard<decltype(_mutex)> lock(_mutex);

        auto deviceId = getCurrentDevice();
        Nd4jPointer constantPtr = nullptr;
        Nd4jLong constantOffset = 0L;

        if (_devicePointers[deviceId] == nullptr) {
            auto constant = getConstantSpace();

            // filling default ptr, which will be 0 probably
            _devicePointers[deviceId] = constant;
            _deviceOffsets[deviceId] = 0;
            constantPtr = constant;
        } else {
            constantPtr = _devicePointers[deviceId];
            constantOffset = _deviceOffsets[deviceId];
        }

        if (constantOffset + numBytes >= CONSTANT_LIMIT) {
            // Does not fit below the limit: use a separately allocated buffer.
            int8_t *ptr = nullptr;
            ALLOCATE_SPECIAL(ptr, workspace, numBytes, int8_t);

            auto res = cudaMemcpy(ptr, src, numBytes, cudaMemcpyHostToDevice);
            if (res != 0)
                throw cuda_exception::build("cudaMemcpy failed", res);

            return ptr;
        }

        // Round the reserved size up to a multiple of 8 so that the next offset
        // handed out stays a multiple of 8; only the real payload is copied.
        const auto rem = numBytes % 8;
        const auto paddedBytes = rem != 0 ? numBytes + (8 - rem) : numBytes;

        // src converts implicitly to the const void* the API expects.
        auto res = cudaMemcpyToSymbol(deviceConstantMemory, src, numBytes, constantOffset, cudaMemcpyHostToDevice);
        if (res != 0)
            throw cuda_exception::build("cudaMemcpyToSymbol failed", res);

        // Advance the offset only once the copy has succeeded, so a failed copy
        // does not consume constant space.
        _deviceOffsets[deviceId] += paddedBytes;

        // NOTE(review): the cast's template argument was stripped in extraction;
        // a byte-sized pointee is assumed because the offset is counted in bytes.
        return reinterpret_cast<int8_t *>(constantPtr) + constantOffset;
    }

    /**
     * Returns the buffer cached for (current device, descriptor, dataType),
     * creating it on first request: the descriptor's values are converted to
     * dataType in a new host array, replicated to the device with
     * replicatePointer(), and stored in the descriptor's holder.
     *
     * NOTE(review): _cache and _counters are accessed here without taking _mutex;
     * confirm whether callers serialize access.
     *
     * @param descriptor values the constant buffer is built from
     * @param dataType element type of the requested buffer
     * @return pointer to the buffer stored in the cache's holder
     */
    ConstantDataBuffer* ConstantHelper::constantBuffer(const ConstantDescriptor &descriptor, nd4j::DataType dataType) {
        const auto deviceId = getCurrentDevice();

        if (_cache[deviceId].count(descriptor) == 0) {
            ConstantHolder holder;
            _cache[deviceId][descriptor] = holder;
        }

        ConstantHolder* holder = &_cache[deviceId][descriptor];

        if (holder->hasBuffer(dataType))
            return holder->getConstantDataBuffer(dataType);

        auto numBytes = descriptor.length() * DataTypeUtils::sizeOf(dataType);
        auto cbuff = new int8_t[numBytes];
        _counters[deviceId] += numBytes;

        // create buffer with this dtype
        // NOTE(review): the const_cast template arguments were stripped in
        // extraction; they are inferred from the (DOUBLE, double) and
        // (INT64, Nd4jLong) selector type lists.
        // NOTE(review): if the descriptor is neither float nor integer, cbuff is
        // passed on without being written -- confirm that cannot happen.
        if (descriptor.isFloat()) {
            BUILD_DOUBLE_SELECTOR(nd4j::DataType::DOUBLE, dataType, nd4j::SpecialTypeConverter::convertGeneric, (nullptr, const_cast<double *>(descriptor.floatValues().data()), descriptor.length(), cbuff), (nd4j::DataType::DOUBLE, double), LIBND4J_TYPES);
        } else if (descriptor.isInteger()) {
            BUILD_DOUBLE_SELECTOR(nd4j::DataType::INT64, dataType, nd4j::SpecialTypeConverter::convertGeneric, (nullptr, const_cast<Nd4jLong *>(descriptor.integerValues().data()), descriptor.length(), cbuff), (nd4j::DataType::INT64, Nd4jLong), LIBND4J_TYPES);
        }

        auto dbuff = replicatePointer(cbuff, descriptor.length() * DataTypeUtils::sizeOf(dataType));

        ConstantDataBuffer dataBuffer(cbuff, dbuff, descriptor.length(), DataTypeUtils::sizeOf(dataType));
        holder->addBuffer(dataBuffer, dataType);

        return holder->getConstantDataBuffer(dataType);
    }

    /**
     * @param deviceId index of the device to query
     * @return the byte counter kept for deviceId, or 0 when deviceId is not in
     *         [0, number of devices)
     */
    Nd4jLong ConstantHelper::getCachedAmount(int deviceId) {
        int numDevices = getNumberOfDevices();

        // `>=`, not `>`: _counters holds numDevices entries, so numDevices
        // itself is already one past the last valid index.
        if (deviceId >= numDevices || deviceId < 0)
            return 0L;

        return _counters[deviceId];
    }

    nd4j::ConstantHelper* nd4j::ConstantHelper::_INSTANCE = nullptr;
}