[WIP] More CUDA fixes/updates (#62)
* CUDA reallocation update
* Legacy SoftMax/LogSoftMax/SoftMaxDerivative removed from cpp
* SoftMaxDerivative op removed
* A few test updates
* RNG fixes
* A few more test updates
* Legacy Histogram/Pooling2D removed
* Legacy Histogram removed
* Histogram moved
* Histogram moved (CUDA)
* Histogram custom op

Signed-off-by: raver119 <raver119@gmail.com>
This commit is contained in:
parent 5cf6859fc4 · commit fd6c0df024
				| @ -827,6 +827,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext  *lc, | ||||
|     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); | ||||
| 
 | ||||
|     BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); | ||||
| 
 | ||||
|     auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state); | ||||
|     rng->rewindH(shape::length(hZShapeInfo)); | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////
 | ||||
| @ -842,6 +845,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext  *lc, | ||||
|     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); | ||||
| 
 | ||||
|     BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); | ||||
| 
 | ||||
|     auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state); | ||||
|     rng->rewindH(shape::length(hZShapeInfo)); | ||||
| } | ||||
| 
 | ||||
| ////////////////////////////////////////////////////////////////////////
 | ||||
| @ -859,6 +865,9 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext  *lc, | ||||
|     auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo); | ||||
| 
 | ||||
|     BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::execTransform(opNum, state, hX, hXShapeInfo, hY, hYShapeInfo, hZ, hZShapeInfo, extraArguments), FLOAT_TYPES); | ||||
| 
 | ||||
|     auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(state); | ||||
|     rng->rewindH(shape::length(hZShapeInfo)); | ||||
| } | ||||
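All three host-side execRandom overloads gain the same tail: after the transform runs, the host RandomGenerator is advanced by the number of values the op consumed, so back-to-back calls draw from disjoint regions of the random stream. A minimal sketch of the counter-based design this relies on (a hypothetical simplified model, not the actual nd4j::graph::RandomGenerator):

    #include <cstdint>

    // Each value is a pure function of (seed, counter), so "skipping" values
    // that a kernel already consumed is just bumping the counter -- which is
    // all rewindH() needs to do on the host side.
    struct CounterRng {
        uint64_t seed;
        uint64_t counter = 0;
        static uint64_t mix(uint64_t s, uint64_t c) {
            uint64_t x = s ^ (c * 0x9E3779B97F4A7C15ULL);
            x ^= x >> 33; x *= 0xFF51AFD7ED558CCDULL; x ^= x >> 33;
            return x;
        }
        uint64_t next() { return mix(seed, counter++); }
        void rewindH(uint64_t steps) { counter += steps; }
    };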
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -774,78 +774,7 @@ void NativeOpExecutioner::execTransformStrict(nd4j::LaunchContext  *lc, | ||||
|     if (xType != zType || !DataTypeUtils::isR(xType)) | ||||
|         throw datatype_exception::build("NativeOpExecutioner::execTransformStrict requires X & Z to have same floating point type", xType, zType); | ||||
| 
 | ||||
|     switch (opNum) { | ||||
|         case transform::SoftMax: | ||||
|         case transform::SoftMaxDerivative: | ||||
|         case transform::LogSoftMax: { | ||||
|                 if (shape::isVector(hXShapeInfo)) { | ||||
|                     int length = shape::length(hXShapeInfo); | ||||
|                     int block = nd4j::math::nd4j_min<int>(length, 256); | ||||
| 
 | ||||
|                     launchDims.x = 1; | ||||
|                     launchDims.y = block; | ||||
|                     launchDims.z += (block * sizeof(double) * 4); | ||||
| 
 | ||||
|                     BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, lc->getAllocationPointer(), lc->getReductionPointer(), nullptr, nullptr), FLOAT_TYPES); | ||||
|                 } else { | ||||
|                     auto shape = shape::shapeOf(hXShapeInfo); | ||||
|                     auto reductionPointer = lc->getReductionPointer(); | ||||
| 					auto allocationPointer = lc->getAllocationPointer(); | ||||
| 					auto specialPointer = reinterpret_cast<double *>(allocationPointer); | ||||
| 
 | ||||
|                     // special pointer for special buffer for special ops | ||||
|                     auto dimension = reinterpret_cast<int *>(specialPointer); | ||||
|                     auto maxDimension = dimension + 1; | ||||
|                     auto maxShapeBuffer = reinterpret_cast<Nd4jLong *>(maxDimension + 1); | ||||
|                     auto special = reinterpret_cast<double *> (maxShapeBuffer + (MAX_RANK * 2 + 4)); | ||||
| 
 | ||||
| 
 | ||||
|                     Nd4jLong maxShape[2] = {shape::shapeOf(hXShapeInfo)[0], 1}; | ||||
|                     auto hostMaxShapeBuffer = shape::shapeBuffer(2, xType, maxShape); | ||||
| 
 | ||||
|                     prepareShapeBuffer<<<1, 1, 128, *stream>>>(dimension, maxDimension, maxShapeBuffer, shape[0], xType); | ||||
| 
 | ||||
|                     DEBUG_KERNEL(stream, opNum); | ||||
| 
 | ||||
|                     // max 3 | ||||
|                     execReduceSame(lc, reduce::Max, hX, hXShapeInfo, dX, dXShapeInfo, extraParams, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, maxDimension, 1, tadShapeInfo, tadOffsets); | ||||
| 
 | ||||
|                     DEBUG_KERNEL(stream, opNum); | ||||
| 
 | ||||
|                     // sub 1 | ||||
|                     execBroadcast(lc, broadcast::Subtract, hX, hXShapeInfo, dX, dXShapeInfo, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, nullptr, hZShapeInfo, dZ, dZShapeInfo, dimension, 1, tadShapeInfo, tadOffsets, nullptr, nullptr); | ||||
| 
 | ||||
|                     DEBUG_KERNEL(stream, opNum); | ||||
| 
 | ||||
|                     // exp 3 | ||||
|                     execTransformFloat(lc, transform::Exp, hZ, hZShapeInfo, dZ, dZShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); | ||||
| 
 | ||||
|                     DEBUG_KERNEL(stream, opNum); | ||||
| 
 | ||||
|                     //sum 1 | ||||
|                     execReduceSame(lc, reduce::Sum, hZ, hZShapeInfo, dZ, dZShapeInfo, extraParams, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, maxDimension, 1, tadShapeInfo, tadOffsets); | ||||
| 
 | ||||
|                     // divide 3 | ||||
|                     execBroadcast(lc, broadcast::Divide, hZ, hZShapeInfo, dZ, dZShapeInfo, nullptr, hostMaxShapeBuffer, special, maxShapeBuffer, nullptr, hZShapeInfo, dZ, dZShapeInfo, dimension, 1, tadShapeInfo, tadOffsets, nullptr, nullptr); | ||||
| 
 | ||||
|                     DEBUG_KERNEL(stream, opNum); | ||||
| 
 | ||||
|                     // log 3 | ||||
|                     if (opNum == transform::LogSoftMax) | ||||
|                         execTransformFloat(lc, transform::Log, nullptr, hZShapeInfo, dZ, dZShapeInfo, nullptr, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); | ||||
|                     else if (opNum == transform::SoftMaxDerivative) | ||||
|                         execTransformStrict(lc, transform::SpecialDerivative, nullptr, hZShapeInfo, dZ, dZShapeInfo, nullptr, hZShapeInfo, dZ, dZShapeInfo, extraParams, tadShapeInfo, tadOffsets); | ||||
| 
 | ||||
|                     nd4j::DebugHelper::checkErrorCode(stream, "SoftMax(...) failed"); | ||||
| 
 | ||||
|                     delete hostMaxShapeBuffer; | ||||
|                 } | ||||
|             } | ||||
|             break; | ||||
|         default: { | ||||
|             BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES); | ||||
|         } | ||||
|     } | ||||
|     BUILD_SINGLE_SELECTOR(xType, functions::transform::TransformStrict, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), FLOAT_TYPES); | ||||
| } | ||||
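The deleted branch assembled SoftMax/LogSoftMax/SoftMaxDerivative out of primitive launches (reduce-max, broadcast-subtract, exp, reduce-sum, broadcast-divide); with those enums gone from the strict transform list, the function collapses to a single dispatch. For reference, the numerically stable formulation that five-step chain implements (a plain C++ sketch, not the library's kernel):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Stable softmax: subtracting the max first keeps exp() from overflowing.
    // This is exactly the max / sub / exp / sum / divide sequence the removed
    // CUDA branch stitched together from primitive ops.
    std::vector<double> softmax(const std::vector<double> &x) {
        const double mx = *std::max_element(x.begin(), x.end());
        std::vector<double> out(x.size());
        double sum = 0.0;
        for (std::size_t i = 0; i < x.size(); ++i) {
            out[i] = std::exp(x[i] - mx);
            sum += out[i];
        }
        for (double &v : out)
            v /= sum;
        return out;
    }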
| 
 | ||||
| //////////////////////////////////////////////////////////////////////// | ||||
| @ -869,22 +798,8 @@ void NativeOpExecutioner::execTransformFloat(nd4j::LaunchContext  *lc, | ||||
|     if (!DataTypeUtils::isR(zType)) | ||||
|         throw datatype_exception::build("NativeOpExecutioner::execTransformFloat requires Z to have floating point type", zType); | ||||
| 
 | ||||
|     if (opNum == transform::Histogram) { | ||||
|         dim3 launchDims(256, 256, 32768); | ||||
| 
 | ||||
|         Nd4jPointer maskedallocationPointer; | ||||
|         auto length = shape::length(hZShapeInfo); | ||||
|         cudaMalloc(reinterpret_cast<void **>(&maskedallocationPointer), length * launchDims.x * DataTypeUtils::sizeOf(nd4j::DataType::INT64)); | ||||
|         auto imaskedallocationPointer = reinterpret_cast<int *>(maskedallocationPointer); | ||||
| 
 | ||||
|         BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, imaskedallocationPointer, reductionPointer, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES); | ||||
| 
 | ||||
|         checkCudaErrors(cudaStreamSynchronize(*stream)); | ||||
|         cudaFree(maskedallocationPointer); | ||||
|     } else { | ||||
|         dim3 launchDims(512, 512, 16384); | ||||
|         BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES); | ||||
|     } | ||||
|     dim3 launchDims(512, 512, 16384); | ||||
|     BUILD_DOUBLE_SELECTOR(xType, zType, functions::transform::TransformFloat, ::executeTransformShaped(launchDims, stream, opNum, dX, dXShapeInfo, xRank, extraParams, dZ, dZShapeInfo, zRank, nullptr, nullptr, nullptr, nullptr), LIBND4J_TYPES, FLOAT_TYPES); | ||||
| } | ||||
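The removed Histogram branch paid a cudaMalloc / synchronize / cudaFree on every invocation; histogram now lives behind the custom-op interface (added later in this commit), and all float transforms share one launch configuration. A sketch of the usual alternative when per-call scratch memory is genuinely needed -- a caller-owned, grow-only workspace (illustrative, not the library's workspace API):

    #include <cuda_runtime.h>
    #include <cstddef>

    // Grow-only device scratch buffer: reallocation happens only when a call
    // needs more space than any previous one, instead of malloc/free per call.
    struct DeviceScratch {
        void  *ptr  = nullptr;
        size_t size = 0;
        void reserve(size_t bytes) {
            if (bytes > size) {
                cudaFree(ptr);              // no-op while ptr == nullptr
                cudaMalloc(&ptr, bytes);
                size = bytes;
            }
        }
        ~DeviceScratch() { cudaFree(ptr); }
    };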
| 
 | ||||
| 
 | ||||
| @ -1197,12 +1112,15 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext  *lc, | ||||
|     dim3 launchDims = dim3(512, 512, 32768); | ||||
|     auto zType = nd4j::ArrayOptions::dataType(hZShapeInfo); | ||||
| 
 | ||||
|     auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(stateHost); | ||||
| 
 | ||||
|     // functions::random::RandomFunction<float>::executeCudaSingle(launchDims, extraPointers, opNum, stateHost, dZ, dZShapeInfo, extraArguments), | ||||
|     BUILD_SINGLE_SELECTOR(zType, functions::random::RandomFunction, ::executeCudaSingle(launchDims, stream, opNum, stateDevice, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); | ||||
| 
 | ||||
|     checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream)); | ||||
|     checkCudaErrors(cudaStreamSynchronize(*stream)); | ||||
|     cudaFree(stateDevice); | ||||
| 
 | ||||
|     rng->rewindH(shape::length(hZShapeInfo)); | ||||
| } | ||||
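This overload and the two below follow one round-trip: host RNG state is copied to the device, the kernel mutates the device copy, the state is copied back, and only then does the host advance its counter. Condensed into one helper, error checks elided (types and the rewindH call are the ones used throughout this diff):

    // Mirrors the pattern in all three CUDA execRandom overloads.
    void execRandomRoundTrip(void *stateHost, void *stateDevice, size_t sizeOf,
                             cudaStream_t *stream, nd4j::graph::RandomGenerator *rng,
                             Nd4jLong consumed) {
        cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream);
        // ... kernel launch consuming stateDevice goes here ...
        cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream);
        cudaStreamSynchronize(*stream);   // state must be back before the host reads it
        cudaFree(stateDevice);
        rng->rewindH(consumed);           // host view skips what the kernel consumed
    }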
| 
 | ||||
| //////////////////////////////////////////////////////////////////////// | ||||
| @ -1224,14 +1142,17 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext  *lc, | ||||
|     checkCudaErrors(cudaStreamSynchronize(*stream)); | ||||
|     checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream)); | ||||
| 
 | ||||
|     auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(stateHost); | ||||
| 
 | ||||
|     dim3 launchDims = dim3(512, 512, 32768); | ||||
|     auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo); | ||||
|     // functions::random::RandomFunction<float>::executeCudaDouble(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments); | ||||
|     BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaDouble(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); | ||||
| 
 | ||||
|     checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream)); | ||||
|     checkCudaErrors(cudaStreamSynchronize(*stream)); | ||||
|     cudaFree(stateDevice); | ||||
| 
 | ||||
|     rng->rewindH(shape::length(hZShapeInfo)); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////// | ||||
| @ -1254,14 +1175,17 @@ void NativeOpExecutioner::execRandom(nd4j::LaunchContext  *lc, | ||||
|     checkCudaErrors(cudaStreamSynchronize(*stream)); | ||||
|     checkCudaErrors(cudaMemcpyAsync(stateDevice, stateHost, sizeOf, cudaMemcpyHostToDevice, *stream)); | ||||
| 
 | ||||
|     auto rng = reinterpret_cast<nd4j::graph::RandomGenerator*>(stateHost); | ||||
| 
 | ||||
|     dim3 launchDims = dim3(512, 512, 32768); | ||||
|     auto xType = nd4j::ArrayOptions::dataType(hZShapeInfo); | ||||
|     // functions::random::RandomFunction<float>::executeCudaTriple(launchDims, extraPointers, opNum, stateHost, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments); | ||||
|     BUILD_SINGLE_SELECTOR(xType, functions::random::RandomFunction, ::executeCudaTriple(launchDims, stream, opNum, stateDevice, dX, dXShapeInfo, dY, dYShapeInfo, dZ, dZShapeInfo, extraArguments), FLOAT_TYPES); | ||||
| 
 | ||||
|     checkCudaErrors(cudaMemcpyAsync(stateHost, stateDevice, sizeOf, cudaMemcpyDeviceToHost, *stream)); | ||||
|     checkCudaErrors(cudaStreamSynchronize(*stream)); | ||||
|     cudaFree(stateDevice); | ||||
| 
 | ||||
|     rng->rewindH(shape::length(hZShapeInfo)); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////// | ||||
|  | ||||
| @ -1958,7 +1958,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers, | ||||
|                           void *extraArguments) { | ||||
| 
 | ||||
|     LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); | ||||
|     NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); | ||||
|     NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////// | ||||
| @ -1970,7 +1970,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer st | ||||
| 						   void *extraArguments) { | ||||
| 
 | ||||
|     LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); | ||||
|     NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); | ||||
|     NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); | ||||
| } | ||||
| 
 | ||||
| //////////////////////////////////////////////////////////////////////// | ||||
| @ -1984,7 +1984,7 @@ void NativeOps::execRandom(Nd4jPointer *extraPointers, int opNum, Nd4jPointer st | ||||
| 							void *extraArguments) { | ||||
| 
 | ||||
|     LaunchContext lc(extraPointers[1], extraPointers[4], extraPointers[5], extraPointers[3]); | ||||
|     NativeOpExecutioner::execRandom(&lc, opNum, extraPointers, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); | ||||
|     NativeOpExecutioner::execRandom(&lc, opNum, stateHost, hX, hXShapeInfo, dX, dXShapeInfo, hY, hYShapeInfo, dY, dYShapeInfo, hZ, hZShapeInfo, dZ, dZShapeInfo, extraArguments); | ||||
| } | ||||
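All three wrappers previously forwarded extraPointers where the executioner expects the packed RNG state; they now pass stateHost. Both parameters are untyped pointers, so the compiler had no chance of flagging the mix-up. A sketch of the strong-typedef trick that would turn this into a compile error (purely illustrative, not the library's API):

    // Distinct wrapper types: an ExtraPtrs cannot be passed where RngState
    // is expected, so the argument swap fixed above could not have compiled.
    struct RngState  { void  *p; };
    struct ExtraPtrs { void **p; };

    void execRandomChecked(RngState state, ExtraPtrs extras);    // declaration only
    // execRandomChecked(ExtraPtrs{nullptr}, RngState{nullptr}); // error: no match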
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -152,9 +152,9 @@ namespace functions { | ||||
|                 __syncthreads(); | ||||
| 
 | ||||
|                 // using this loop instead of memcpy | ||||
|                 for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { | ||||
|                 for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) | ||||
|                     cB[e] = dB[e]; | ||||
|                 } | ||||
| 
 | ||||
|                 __syncthreads(); | ||||
| 
 | ||||
| 
 | ||||
| @ -213,9 +213,9 @@ namespace functions { | ||||
|                 __syncthreads(); | ||||
| 
 | ||||
|                 // using this loop instead of memcpy | ||||
|                 for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { | ||||
|                 for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) | ||||
|                     cB[e] = dB[e]; | ||||
|                 } | ||||
| 
 | ||||
|                 __syncthreads(); | ||||
| 
 | ||||
| 
 | ||||
| @ -262,9 +262,9 @@ namespace functions { | ||||
|                 __syncthreads(); | ||||
| 
 | ||||
|                 // using this loop instead of memcpy | ||||
|                 for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { | ||||
|                 for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) | ||||
|                     cB[e] = dB[e]; | ||||
|                 } | ||||
| 
 | ||||
|                 __syncthreads(); | ||||
| 
 | ||||
|                 int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
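The brace cleanup above touches the cooperative-copy idiom used throughout these kernels: every thread of the block copies one byte per iteration until the whole RandomGenerator sits in shared memory, followed by a barrier. As a standalone kernel (illustrative):

    // Block-cooperative byte copy of an object into dynamic shared memory.
    __global__ void cooperativeCopy(const unsigned char *src, size_t objSize) {
        extern __shared__ unsigned char shmem[];
        for (size_t e = threadIdx.x; e < objSize; e += blockDim.x)
            shmem[e] = src[e];      // strided: thread i copies bytes i, i+B, ...
        __syncthreads();            // all bytes visible before anyone reads
        // ... reinterpret_cast<nd4j::graph::RandomGenerator *>(shmem) ...
    }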
|  | ||||
| @ -107,9 +107,6 @@ | ||||
| 
 | ||||
| 
 | ||||
| #define TRANSFORM_STRICT_OPS \ | ||||
|         (0, SoftMax), \ | ||||
|         (1, SoftMaxDerivative), \ | ||||
|         (2, LogSoftMax) ,\ | ||||
|         (3, ELUDerivative), \ | ||||
|         (4, TanhDerivative), \ | ||||
|         (5, HardTanhDerivative), \ | ||||
| @ -167,9 +164,7 @@ | ||||
| 
 | ||||
| // these ops return one of FLOAT data types
 | ||||
| #define TRANSFORM_FLOAT_OPS \ | ||||
|         (0, Histogram), \ | ||||
|         (1, Sqrt), \ | ||||
|         (2, Pooling2D) ,\ | ||||
|         (3, RSqrt) | ||||
| 
 | ||||
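Dropping (0, SoftMax), (1, SoftMaxDerivative), (2, LogSoftMax) and (0, Histogram), (2, Pooling2D) needs no renumbering, because every entry carries its opNum explicitly; the surviving (1, Sqrt) and (3, RSqrt) keep their numbers. These lists feed the BUILD_*_SELECTOR dispatch machinery; a minimal model of the technique (an X-macro switch, names illustrative):

    #include <cstdio>

    // Each entry is (opNum, Name); removing one leaves the others intact.
    #define MY_FLOAT_OPS \
        X(1, Sqrt)       \
        X(3, RSqrt)

    void dispatch(int opNum) {
        switch (opNum) {
    #define X(NUM, NAME) case NUM: std::printf(#NAME "\n"); break;
            MY_FLOAT_OPS
    #undef X
            default: std::printf("unknown op %d\n", opNum); break;
        }
    }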
| 
 | ||||
|  | ||||
| @ -0,0 +1,57 @@ | ||||
| /*******************************************************************************
 | ||||
|  * Copyright (c) 2015-2018 Skymind, Inc. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //
 | ||||
| // @author raver119@gmail.com
 | ||||
| //
 | ||||
| 
 | ||||
| #include <op_boilerplate.h> | ||||
| #if NOT_EXCLUDED(OP_histogram) | ||||
| 
 | ||||
| #include <ops/declarable/CustomOperations.h> | ||||
| #include <ops/declarable/helpers/transforms.h> | ||||
| #include <ops/declarable/helpers/histogram.h> | ||||
| 
 | ||||
| namespace nd4j { | ||||
|     namespace ops { | ||||
|         CUSTOM_OP_IMPL(histogram, 1, 1, false, 0, 1) { | ||||
|             auto input = INPUT_VARIABLE(0); | ||||
|             auto numBins = INT_ARG(0); | ||||
|             auto output = OUTPUT_VARIABLE(0); | ||||
| 
 | ||||
|             REQUIRE_TRUE(numBins == output->lengthOf(), 0, "Histogram: numBins must match output length"); | ||||
| 
 | ||||
|             helpers::histogramHelper(block.launchContext(), *input, *output); | ||||
| 
 | ||||
|             return Status::OK(); | ||||
|         } | ||||
| 
 | ||||
|         DECLARE_SHAPE_FN(histogram) { | ||||
|             auto numBins = INT_ARG(0); | ||||
| 
 | ||||
|             return SHAPELIST(ConstantShapeHelper::getInstance()->vectorShapeInfo(numBins, nd4j::DataType::INT64)); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         DECLARE_TYPES(histogram) { | ||||
|             getOpDescriptor() | ||||
|                     ->setAllowedInputTypes(0, {ALL_INTS, ALL_FLOATS}) | ||||
|                     ->setAllowedOutputTypes({ALL_INTS}); | ||||
|         }; | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #endif | ||||
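A usage sketch for the new op (the DeclarableOp execute signature -- inputs, t-args, i-args -- follows the common pattern in this codebase; treat the details as approximate):

    // Build a 10-bin histogram over 1000 float values.
    auto input = NDArrayFactory::create<float>('c', {1000});
    input.linspace(0.0f);

    nd4j::ops::histogram op;
    auto result = op.execute({&input}, {}, {10});   // iArgs[0] = numBins
    auto counts = result->at(0);                    // INT64 vector, length 10
    counts->printIndexedBuffer("histogram");
    delete result;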
| @ -206,6 +206,13 @@ namespace nd4j { | ||||
|         #if NOT_EXCLUDED(OP_hashcode) | ||||
|         DECLARE_REDUCTION_OP(hashcode, 1, 1, false, 0, 0); | ||||
|         #endif | ||||
| 
 | ||||
|         /**
 | ||||
|          * This operation calculates the number of entries per bin | ||||
|          */ | ||||
|         #if NOT_EXCLUDED(OP_histogram) | ||||
|         DECLARE_CUSTOM_OP(histogram, 1, 1, false, 0, 1); | ||||
|         #endif | ||||
|     } | ||||
| } | ||||
| 
 | ||||
|  | ||||
							
								
								
									
libnd4j/include/ops/declarable/helpers/cpu/histogram.cpp (new file, +83 lines)
| @ -0,0 +1,83 @@ | ||||
| /*******************************************************************************
 | ||||
|  * Copyright (c) 2015-2018 Skymind, Inc. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //
 | ||||
| // @author raver119@gmail.com
 | ||||
| //
 | ||||
| 
 | ||||
| #include <ops/declarable/helpers/histogram.h> | ||||
| 
 | ||||
| namespace nd4j { | ||||
|     namespace ops { | ||||
|         namespace helpers { | ||||
|             template <typename X, typename Z> | ||||
|             static void histogram_(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) { | ||||
|                 auto dx = reinterpret_cast<X*>(xBuffer); | ||||
|                 auto result = reinterpret_cast<Z*>(zBuffer); | ||||
| 
 | ||||
|                 int length = shape::length(xShapeInfo); | ||||
|                 // FIXME: 2???
 | ||||
|                 int _threads = 2; | ||||
| 
 | ||||
|                 int span = (length / _threads) + 8; | ||||
| 
 | ||||
|                 X binSize = (max_val - min_val) / (numBins); | ||||
| 
 | ||||
|                 PRAGMA_OMP_PARALLEL_THREADS(_threads) | ||||
|                 { | ||||
|                     int tid, start, end; | ||||
| 
 | ||||
|                     int *bins = new int[numBins]; | ||||
|                     std::memset(bins, 0, sizeof(int) * numBins); | ||||
|                     tid = omp_get_thread_num(); | ||||
|                     start = span * tid; | ||||
|                     end = span * (tid + 1); | ||||
|                     if (end > length) end = length; | ||||
| 
 | ||||
|                     PRAGMA_OMP_SIMD | ||||
|                     for (int x = start; x < end; x++) { | ||||
|                         int idx = (int) ((dx[x] - min_val) / binSize); | ||||
|                         if (idx < 0) | ||||
|                             idx = 0; | ||||
|                         else if (idx >= numBins) | ||||
|                             idx = numBins - 1; | ||||
| 
 | ||||
|                         bins[idx]++; | ||||
|                     } | ||||
| 
 | ||||
|                     PRAGMA_OMP_CRITICAL | ||||
|                     { | ||||
|                         PRAGMA_OMP_SIMD | ||||
|                         for (int x = 0; x < numBins; x++) { | ||||
|                             result[x] += bins[x]; | ||||
|                         } | ||||
| 
 | ||||
|                     } | ||||
| 
 | ||||
|                     delete[] bins; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) { | ||||
|                 Nd4jLong numBins = output.lengthOf(); | ||||
|                 double min_val = input.reduceNumber(reduce::SameOps::Min).e<double>(0); | ||||
|                 double max_val = input.reduceNumber(reduce::SameOps::Max).e<double>(0); | ||||
| 
 | ||||
|                 BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (input.buffer(), input.shapeInfo(), output.getBuffer(), output.getShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
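The CPU helper is the classic privatization pattern: each thread counts into its own bin array, and only the short merge runs under a critical section, keeping atomics out of the hot loop. The same idea, self-contained, with OpenMP left to split the range (so the fixed _threads = 2 and the span + 8 padding above aren't needed):

    #include <vector>

    std::vector<int> histogram(const std::vector<float> &x, int numBins,
                               float minVal, float maxVal) {
        std::vector<int> global(numBins, 0);
        const float binSize = (maxVal - minVal) / numBins;
        #pragma omp parallel
        {
            std::vector<int> local(numBins, 0);   // private bins: no contention
            #pragma omp for nowait
            for (long i = 0; i < (long) x.size(); i++) {
                int idx = (int) ((x[i] - minVal) / binSize);
                if (idx < 0) idx = 0;
                if (idx >= numBins) idx = numBins - 1;
                local[idx]++;
            }
            #pragma omp critical
            for (int b = 0; b < numBins; b++)
                global[b] += local[b];            // short merge, once per thread
        }
        return global;
    }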
							
								
								
									
libnd4j/include/ops/declarable/helpers/cuda/histogram.cu (new file, +134 lines)
| @ -0,0 +1,134 @@ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2015-2018 Skymind, Inc. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| // | ||||
| // @author raver119@gmail.com | ||||
| // | ||||
| 
 | ||||
| #include <ops/declarable/helpers/histogram.h> | ||||
| #include <NDArrayFactory.h> | ||||
| 
 | ||||
| namespace nd4j { | ||||
|     namespace ops { | ||||
|         namespace helpers { | ||||
|             template <typename X, typename Z> | ||||
|             void _CUDA_G histogramKernel(void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, void *allocationPointer, void *reductionPointer, Nd4jLong numBins, double min_val, double max_val) { | ||||
|                 int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
|                 auto dx = reinterpret_cast<X*>(xBuffer); | ||||
|                 auto result = reinterpret_cast<Z*>(zBuffer); | ||||
| 
 | ||||
|                 __shared__ Z *bins; | ||||
|                 __shared__ int length; | ||||
|                 __shared__ Z *reductor; | ||||
|                 if (threadIdx.x == 0) { | ||||
|                     extern __shared__ unsigned char shmem[]; | ||||
|                     bins = (Z *) shmem; | ||||
|                     reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x); | ||||
| 
 | ||||
|                     length = shape::length(xShapeInfo); | ||||
|                 } | ||||
|                 __syncthreads(); | ||||
| 
 | ||||
|                 Z binSize = (max_val - min_val) / (numBins); | ||||
| 
 | ||||
|                 for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                     bins[e] = (Z) 0.0f; | ||||
|                 } | ||||
|                 __syncthreads(); | ||||
| 
 | ||||
|                 for (int e = tid; e < length; e+= blockDim.x * gridDim.x) { | ||||
|                     int idx = (int) ((dx[e] - min_val) / binSize); | ||||
|                     if (idx < 0) idx = 0; | ||||
|                     else if (idx >= numBins) idx = numBins - 1; | ||||
| 
 | ||||
|                     nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f); | ||||
|                 } | ||||
|                 __syncthreads(); | ||||
| 
 | ||||
|                 // transfer shared memory to reduction memory | ||||
| 
 | ||||
| 
 | ||||
|                 if (gridDim.x > 1) { | ||||
|                     unsigned int *tc = (unsigned int *)reductionPointer; | ||||
|                     __shared__ bool amLast; | ||||
| 
 | ||||
|                     for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                         reductor[e] = bins[e]; | ||||
|                     } | ||||
|                     __threadfence(); | ||||
|                     __syncthreads(); | ||||
| 
 | ||||
|                     if (threadIdx.x == 0) { | ||||
|                         unsigned int ticket = atomicInc(&tc[16384], gridDim.x); | ||||
|                         amLast = (ticket == gridDim.x - 1); | ||||
|                     } | ||||
|                     __syncthreads(); | ||||
| 
 | ||||
|                     if (amLast) { | ||||
|                         tc[16384] = 0; | ||||
| 
 | ||||
|                         // nullify shared memory for future accumulation | ||||
|                         for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                             bins[e] = (Z) 0.0f; | ||||
|                         } | ||||
| 
 | ||||
|                         // accumulate reduced bins | ||||
|                         for (int r = 0; r < gridDim.x; r++) { | ||||
|                             Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins); | ||||
| 
 | ||||
|                             for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                                 bins[e] += ptrBuf[e]; | ||||
|                             } | ||||
|                         } | ||||
|                         __syncthreads(); | ||||
| 
 | ||||
|                         // write them out to Z | ||||
|                         for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                             result[e] = bins[e]; | ||||
|                         } | ||||
|                     } | ||||
|                 } else { | ||||
|                     // if there's only 1 block - just write away data | ||||
|                     for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                         result[e] = bins[e]; | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             template <typename X, typename Z> | ||||
|             static void histogram_(nd4j::LaunchContext *context, void *xBuffer, Nd4jLong *xShapeInfo, void *zBuffer, Nd4jLong *zShapeInfo, Nd4jLong numBins, double min_val, double max_val) { | ||||
|                 int numThreads = 256; | ||||
|                 int numBlocks = nd4j::math::nd4j_min<int>(256, nd4j::math::nd4j_max<int>(1, shape::length(xShapeInfo) / numThreads)); | ||||
|                 int workspaceSize = numBlocks * numBins; | ||||
|                 auto tmp = NDArrayFactory::create<Z>('c',{workspaceSize}); | ||||
| 
 | ||||
|                 histogramKernel<X, Z><<<numBlocks, numThreads, 32768, *context->getCudaStream()>>>(xBuffer, xShapeInfo, zBuffer, zShapeInfo, tmp.getSpecialBuffer(), context->getReductionPointer(), numBins, min_val, max_val); | ||||
| 
 | ||||
|                 cudaStreamSynchronize(*context->getCudaStream()); | ||||
|             } | ||||
| 
 | ||||
|             void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output) { | ||||
|                 Nd4jLong numBins = output.lengthOf(); | ||||
|                 double min_val = input.reduceNumber(reduce::SameOps::Min).e<double>(0); | ||||
|                 double max_val = input.reduceNumber(reduce::SameOps::Max).e<double>(0); | ||||
| 
 | ||||
|                 BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), histogram_, (context, input.specialBuffer(), input.specialShapeInfo(), output.getSpecialBuffer(), output.getSpecialShapeInfo(), numBins, min_val, max_val), LIBND4J_TYPES, INTEGER_TYPES); | ||||
| 
 | ||||
|                 NDArray::registerSpecialUse({&output}, {&input}); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| } | ||||
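The grid-wide merge is the "last block" pattern: every block spills its shared-memory bins to a global scratch area, a ticket from atomicInc elects whichever block finishes last, and that block alone folds all partials into the output. (The hard-coded tc[16384] slot assumes the reduction buffer is at least that large.) The election, reduced to one value per block (illustrative):

    __device__ unsigned int ticket = 0;   // the real code keeps this in reductionPointer

    __global__ void lastBlockSum(float *partials, float *out) {
        __shared__ bool amLast;
        if (threadIdx.x == 0)
            partials[blockIdx.x] = (float) blockIdx.x;   // toy per-block partial
        __threadfence();                                  // publish before voting
        if (threadIdx.x == 0)
            amLast = (atomicInc(&ticket, gridDim.x) == gridDim.x - 1);
        __syncthreads();
        if (amLast && threadIdx.x == 0) {                 // exactly one block, one thread
            float sum = 0.0f;
            for (unsigned r = 0; r < gridDim.x; r++)
                sum += partials[r];
            *out = sum;
            ticket = 0;                                   // re-arm for the next launch
        }
    }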
							
								
								
									
libnd4j/include/ops/declarable/helpers/histogram.h (new file, +34 lines)
| @ -0,0 +1,34 @@ | ||||
| /*******************************************************************************
 | ||||
|  * Copyright (c) 2015-2018 Skymind, Inc. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0.
 | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| //
 | ||||
| // @author raver119@gmail.com
 | ||||
| //
 | ||||
| 
 | ||||
| #ifndef LIBND4J_HISTOGRAM_H | ||||
| #define LIBND4J_HISTOGRAM_H | ||||
| 
 | ||||
| #include <NDArray.h> | ||||
| 
 | ||||
| namespace nd4j { | ||||
|     namespace ops { | ||||
|         namespace helpers { | ||||
|             void histogramHelper(nd4j::LaunchContext *context, NDArray &input, NDArray &output); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| 
 | ||||
| #endif //LIBND4J_HISTOGRAM_H
 | ||||
| @ -747,88 +747,7 @@ static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outSha | ||||
|                              int *allocationPointer, Z *reductionPointer,  | ||||
|                              Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { | ||||
| 
 | ||||
|             int numBins = (int) extraParams[0]; | ||||
|             Z min_val = extraParams[1]; | ||||
|             Z max_val = extraParams[2]; | ||||
| 
 | ||||
|             int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 
 | ||||
|             __shared__ Z *bins; | ||||
|             __shared__ int length; | ||||
|             __shared__ Z *reductor; | ||||
|             if (threadIdx.x == 0) { | ||||
|                 extern __shared__ unsigned char shmem[]; | ||||
|                 bins = (Z *) shmem; | ||||
|                 reductor = ((Z *) allocationPointer) + (numBins * blockIdx.x); | ||||
| 
 | ||||
|                 length = shape::length(xShapeBuffer); | ||||
|             } | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             Z binSize = (max_val - min_val) / (numBins); | ||||
| 
 | ||||
|             for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                 bins[e] = (Z) 0.0f; | ||||
|             } | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             for (int e = tid; e < length; e+= blockDim.x * gridDim.x) { | ||||
|                 int idx = (int) ((dx[e] - min_val) / binSize); | ||||
| 				    if (idx < 0) idx = 0; | ||||
| 					else if (idx >= numBins) idx = numBins - 1; | ||||
| 
 | ||||
| 				nd4j::math::atomics::nd4j_atomicAdd(&bins[idx], (Z) 1.0f); | ||||
|             } | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             // transfer shared memory to reduction memory
 | ||||
| 
 | ||||
| 
 | ||||
|             if (gridDim.x > 1) { | ||||
|                 unsigned int *tc = (unsigned int *)reductionPointer; | ||||
|                 __shared__ bool amLast; | ||||
| 
 | ||||
|                 for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                     reductor[e] = bins[e]; | ||||
|                 } | ||||
|                 __threadfence(); | ||||
|                 __syncthreads(); | ||||
| 
 | ||||
|                 if (threadIdx.x == 0) { | ||||
| 						unsigned int ticket = atomicInc(&tc[16384], gridDim.x); | ||||
| 						amLast = (ticket == gridDim.x - 1); | ||||
| 				} | ||||
| 				__syncthreads(); | ||||
| 
 | ||||
| 				if (amLast) { | ||||
| 				    tc[16384] = 0; | ||||
| 
 | ||||
|                     // nullify shared memory for future accumulation
 | ||||
|                     for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                         bins[e] = (Z) 0.0f; | ||||
|                     } | ||||
| 
 | ||||
|                     // accumulate reduced bins
 | ||||
|                     for (int r = 0; r < gridDim.x; r++) { | ||||
|                         Z *ptrBuf = ((Z *)allocationPointer) + (r * numBins); | ||||
| 
 | ||||
|                         for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                             bins[e] += ptrBuf[e]; | ||||
|                         } | ||||
|                     } | ||||
|                     __syncthreads(); | ||||
| 
 | ||||
|                     // write them out to Z
 | ||||
|                     for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                         result[e] = bins[e]; | ||||
|                     } | ||||
| 				} | ||||
|             } else { | ||||
|                 // if there's only 1 block - just write away data
 | ||||
|                 for (int e = threadIdx.x; e < numBins; e += blockDim.x) { | ||||
|                     result[e] = bins[e]; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
| 		}; | ||||
| #endif | ||||
| @ -840,52 +759,7 @@ static void execSpecial(T *in, Nd4jLong *inShapeBuffer, Z *out, Nd4jLong *outSha | ||||
| 				Nd4jLong *zShapeBuffer, | ||||
| 				Z *extraParams, Nd4jLong *tadShapeInfo, Nd4jLong *tadOffsets) { | ||||
| 
 | ||||
| 			int length = shape::length(xShapeBuffer); | ||||
| 			int _threads = 2; | ||||
| 
 | ||||
| 			int numBins = (int) extraParams[0]; | ||||
| 			int span = (length / _threads) + 8; | ||||
| 
 | ||||
| 			// get min over input
 | ||||
|             T min_val = extraParams[1]; | ||||
|             T max_val = extraParams[2]; | ||||
| 
 | ||||
| 			T binSize = (max_val - min_val) / (numBins); | ||||
| 
 | ||||
| 
 | ||||
|             PRAGMA_OMP_PARALLEL_THREADS(_threads) | ||||
| 			{ | ||||
| 				int tid, start, end; | ||||
| 
 | ||||
| 				int *bins = new int[numBins]; | ||||
|                 std::memset(bins, 0, sizeof(int) * numBins); | ||||
| 				tid = omp_get_thread_num(); | ||||
| 				start = span * tid; | ||||
| 				end = span * (tid + 1); | ||||
| 				if (end > length) end = length; | ||||
| 
 | ||||
| 				PRAGMA_OMP_SIMD | ||||
| 				for (int x = start; x < end; x++) { | ||||
| 					int idx = (int) ((dx[x] - min_val) / binSize); | ||||
| 					if (idx < 0) | ||||
| 						idx = 0; | ||||
| 					else if (idx >= numBins) | ||||
| 						idx = numBins - 1; | ||||
| 
 | ||||
| 					bins[idx]++; | ||||
| 				} | ||||
| 
 | ||||
|                 PRAGMA_OMP_CRITICAL | ||||
| 				{ | ||||
|                     PRAGMA_OMP_SIMD | ||||
| 					for (int x = 0; x < numBins; x++) { | ||||
| 						result[x] += bins[x]; | ||||
| 					} | ||||
| 
 | ||||
| 				} | ||||
| 
 | ||||
| 				delete[] bins; | ||||
| 			} | ||||
| 
 | ||||
| 		} | ||||
| 
 | ||||
|  | ||||
| @ -88,10 +88,10 @@ namespace randomOps { | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             // using this loop instead of memcpy
 | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) | ||||
|                 cB[e] = dB[e]; | ||||
|             } | ||||
| //            __syncthreads();  // Eliminated due RTX20xx specific
 | ||||
| 
 | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 
 | ||||
| @ -137,10 +137,6 @@ namespace randomOps { | ||||
| //                    __syncthreads();  // Eliminated due RTX20xx specific
 | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
| //            __syncthreads();  // Eliminated due RTX20xx specific
 | ||||
|             if (threadIdx.x == 0 && blockIdx.x == 0) | ||||
|                 devRng->rewindH(zLength); | ||||
|         } | ||||
| #endif | ||||
| 
 | ||||
| @ -208,9 +204,6 @@ namespace randomOps { | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             // update rng state
 | ||||
|             rng->rewindH(zLength); | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
| @ -277,7 +270,7 @@ namespace randomOps { | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) | ||||
|                 cB[e] = dB[e]; | ||||
| 
 | ||||
| //            __syncthreads();  // Eliminated due RTX20xx specific
 | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 
 | ||||
| @ -299,10 +292,6 @@ namespace randomOps { | ||||
|                     z[epm * zEWS] =  (nd4j::math::nd4j_sqrt<T,T>(t * nd4j::math::nd4j_log<T,T>(r0)) * nd4j::math::nd4j_sin<T,T>(two_pi * r1)) * stddev + realMean1; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
| //            __syncthreads();          // Eliminated due RTX20xx specific
 | ||||
|             if (threadIdx.x == 0 && blockIdx.x == 0) | ||||
|                 devRng->rewindH(zLength); | ||||
|         } | ||||
| #endif | ||||
| 
 | ||||
| @ -352,9 +341,6 @@ namespace randomOps { | ||||
|                     z[epm * zEWS] = z1; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             // update rng state
 | ||||
|             rng->rewindH(zLength); | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
| @ -401,10 +387,10 @@ namespace randomOps { | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             // using this loop instead of memcpy
 | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) | ||||
|                 cB[e] = dB[e]; | ||||
|             } | ||||
| //            __syncthreads();  // Eliminated due RTX20xx specific
 | ||||
| 
 | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 
 | ||||
| @ -423,10 +409,6 @@ namespace randomOps { | ||||
|                 // if trials is set to 0, effectively we just have successful memset
 | ||||
|                 z[e * zEWS] = static_cast<T>(success); | ||||
|             } | ||||
| 
 | ||||
| //            __syncthreads();  // Eliminated due RTX20xx specific
 | ||||
|             if (trials > 0 && threadIdx.x == 0 && blockIdx.x == 0) | ||||
|                 devRng->rewindH(zLength * trials); | ||||
|         } | ||||
| #endif | ||||
| 
 | ||||
| @ -472,10 +454,6 @@ namespace randomOps { | ||||
|                     z[e * zEWS] = static_cast<T>(success); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             // update rng state
 | ||||
|             if (trials > 0) | ||||
|                 rng->rewindH(zLength * trials); | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
| @ -522,10 +500,10 @@ namespace randomOps { | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             // using this loop instead of memcpy
 | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) | ||||
|                 cB[e] = dB[e]; | ||||
|             } | ||||
| //            __syncthreads();  // Eliminated due RTX20xx specific
 | ||||
| 
 | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 
 | ||||
| @ -591,10 +569,6 @@ namespace randomOps { | ||||
|                     z[e * zEWS] = static_cast<T>(success); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             // update rng state
 | ||||
|             if (trials > 0) | ||||
|                 rng->rewindH(zLength * trials); | ||||
|         } | ||||
|     }; | ||||
|      | ||||
| @ -676,10 +650,10 @@ namespace randomOps { | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             // using this loop instead of memcpy
 | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) | ||||
|                 cB[e] = dB[e]; | ||||
|             } | ||||
| //            __syncthreads();  // Eliminated due RTX20xx specific
 | ||||
| 
 | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 
 | ||||
| @ -724,9 +698,6 @@ namespace randomOps { | ||||
|                         z[e] = mean + nd4j::DataTypeUtils::min<T>(); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             // update rng state
 | ||||
|             rng->rewindH(zLength); | ||||
|         } | ||||
|     }; | ||||
| 
 | ||||
| @ -788,10 +759,10 @@ namespace randomOps { | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             // using this loop instead of memcpy
 | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) { | ||||
|             for (int e = threadIdx.x; e < sizeof(nd4j::graph::RandomGenerator); e+= blockDim.x) | ||||
|                 cB[e] = dB[e]; | ||||
|             } | ||||
| //            __syncthreads();  // Eliminated due RTX20xx specific
 | ||||
| 
 | ||||
|             __syncthreads(); | ||||
| 
 | ||||
|             int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||
| 
 | ||||
| @ -868,9 +839,6 @@ namespace randomOps { | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             // update rng state
 | ||||
|             rng->rewindH(zLength); | ||||
|         } | ||||
|     }; | ||||
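These hunks are the per-op counterpart of the executioner changes above: the rewinds are removed from both the device kernels (devRng->rewindH) and the host op bodies (rng->rewindH), leaving a single writer for the RNG counter -- the executioner, once per call, after the stream has synchronized. With both sides rewinding, every launch would skip the consumed values twice and shift all subsequent draws. A sketch of the invariant (using rewindH as it appears throughout this diff):

    // Single-writer rule: ops only *read* RNG state; the executioner advances
    // it exactly once per launch, after copying the state back to the host.
    void advanceAfterLaunch(nd4j::graph::RandomGenerator *rng, Nd4jLong consumed) {
        rng->rewindH(consumed);   // the only place the counter moves forward
    }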
| 
 | ||||
|  | ||||
| @ -1724,9 +1724,9 @@ namespace nd4j { | ||||
|             } | ||||
|         }; | ||||
| 
 | ||||
|         TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); | ||||
|         //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax");
 | ||||
| 
 | ||||
|         output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, "Softmax"); | ||||
|         //output += helper.runOperationSuit(&tbSoftmax, generator2, batch2, "Softmax");
 | ||||
| 
 | ||||
|         return output; | ||||
|     } | ||||
|  | ||||
| @ -60,11 +60,11 @@ namespace nd4j { | ||||
|         sbRelu.setY(NDArrayFactory::create_<T>(0.0)); | ||||
| 
 | ||||
|         TransformBenchmark tbSigmoid(transform::StrictOps::Sigmoid, "sigmoid"); | ||||
|         TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax"); | ||||
|         //TransformBenchmark tbSoftmax(transform::StrictOps::SoftMax, "softmax");
 | ||||
| 
 | ||||
|         output += helper.runOperationSuit(&sbRelu, generator, batch, "RELU"); | ||||
|         output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Sigmoid"); | ||||
|         output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax"); | ||||
|         //output += helper.runOperationSuit(&tbSigmoid, generator, batch, "Softmax");
 | ||||
| 
 | ||||
|         return output; | ||||
|     } | ||||
|  | ||||
| @ -49,7 +49,6 @@ import org.nd4j.linalg.api.ops.impl.summarystats.Variance; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.comparison.*; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.custom.*; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.gradient.*; | ||||
| @ -748,8 +747,6 @@ public class LegacyOpMapper { | ||||
| 
 | ||||
|     public static Class<?> transformFloatingOpClass(int opNum){ | ||||
|         switch (opNum){ | ||||
|             case 0: | ||||
|                 return Histogram.class; | ||||
|             case 1: | ||||
|                 return Sqrt.class; | ||||
|             case 3: | ||||
|  | ||||
| @ -70,10 +70,10 @@ import org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp; | ||||
| import org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp; | ||||
| import org.nd4j.linalg.api.ops.impl.shape.bp.TileBp; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.Assert; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.Histogram; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.bool.BooleanNot; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.custom.*; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.pairwise.BinaryMinimalRelativeError; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.*; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.gradient.*; | ||||
| @ -888,7 +888,6 @@ public class OpValidation { | ||||
|                 SELUDerivative.class, | ||||
|                 SigmoidDerivative.class, | ||||
|                 org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class, | ||||
|                 org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative.class, | ||||
|                 SoftSignDerivative.class, | ||||
|                 TanhDerivative.class, | ||||
|                 SwishDerivative.class, | ||||
| @ -987,7 +986,6 @@ public class OpValidation { | ||||
|                 OneHot.class, | ||||
|                 BinaryMinimalRelativeError.class, | ||||
|                 BinaryMinimalRelativeError.class, | ||||
|                 Histogram.class, | ||||
|                 InvertPermutation.class,    //Uses integer indices | ||||
|                 ConfusionMatrix.class,      //Integer indices | ||||
|                 Linspace.class,             //No input array | ||||
| @ -1056,7 +1054,9 @@ public class OpValidation { | ||||
|                 LogicalAnd.class, | ||||
|                 LogicalNot.class, | ||||
|                 LogicalOr.class, | ||||
|                 LogicalXor.class | ||||
|                 LogicalXor.class, | ||||
| 
 | ||||
|                 Histogram.class | ||||
|         ); | ||||
| 
 | ||||
|         return new HashSet<>(list); | ||||
|  | ||||
| @ -424,7 +424,6 @@ public class ImportClassMapping { | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentProd.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.custom.segment.SegmentSum.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.dtype.Cast.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class, | ||||
| @ -552,7 +551,6 @@ public class ImportClassMapping { | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.strict.Sin.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.strict.Sinh.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign.class, | ||||
|             org.nd4j.linalg.api.ops.impl.transforms.strict.Stabilize.class, | ||||
|  | ||||
| @ -19,20 +19,6 @@ package org.nd4j.linalg.activations; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| import org.nd4j.linalg.activations.impl.*; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.Op; | ||||
| import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; | ||||
| import org.nd4j.linalg.api.ops.impl.scalar.ScalarSet; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.same.OldIdentity; | ||||
| import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear; | ||||
| import org.nd4j.linalg.api.ops.impl.scalar.Step; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.strict.*; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.gradient.*; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.same.Cube; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.strict.SigmoidDerivative; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative; | ||||
| 
 | ||||
| /** | ||||
|  * This enum is the factory for the activation function. | ||||
|  | ||||
| @ -0,0 +1,55 @@ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2015-2018 Skymind, Inc. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| package org.nd4j.linalg.api.ops.impl.transforms; | ||||
| 
 | ||||
| import org.nd4j.base.Preconditions; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.DynamicCustomOp; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| /** | ||||
|  * Histogram op wrapper | ||||
|  * | ||||
|  * @author raver119@gmail.com | ||||
|  */ | ||||
| public class Histogram extends DynamicCustomOp { | ||||
|     private long numBins; | ||||
| 
 | ||||
|     public Histogram() { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     public Histogram(INDArray input, INDArray output) { | ||||
|         Preconditions.checkArgument(output.isZ(), "Histogram op output should have integer data type"); | ||||
| 
 | ||||
|         numBins = output.length(); | ||||
|         inputArguments.add(input); | ||||
|         outputArguments.add(output); | ||||
|         iArguments.add(numBins); | ||||
|     } | ||||
| 
 | ||||
|     public Histogram(INDArray input, long numBins) { | ||||
|         this.numBins = numBins; | ||||
|         inputArguments.add(input); | ||||
|         iArguments.add(numBins); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String opName() { | ||||
|         return "histogram"; | ||||
|     } | ||||
| } | ||||
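
For reference, a minimal usage sketch of the new DynamicCustomOp-based Histogram (a sketch only: the variable names are illustrative, and the executioner call pattern follows the updated tests further below):

    // Count the entries of `input` into 10 equal-width bins between min and max.
    INDArray input = Nd4j.create(new float[] {0f, 0f, 0f, 5f, 5f, 5f, 10f, 10f, 10f});
    Histogram histogram = new Histogram(input, 10);             // output shape comes from numBins
    INDArray counts = Nd4j.getExecutioner().exec(histogram)[0]; // integer bin counts, length 10
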
| @ -1,114 +0,0 @@ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2015-2018 Skymind, Inc. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| package org.nd4j.linalg.api.ops.impl.transforms.floating; | ||||
| 
 | ||||
| import lombok.val; | ||||
| import onnx.OnnxProto3; | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| import org.nd4j.imports.NoOpNameFoundException; | ||||
| import org.nd4j.imports.descriptors.properties.PropertyMapping; | ||||
| import org.nd4j.imports.graphmapper.tf.TFGraphMapper; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.BaseTransformFloatOp; | ||||
| import org.nd4j.linalg.api.ops.BaseTransformOp; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.tensorflow.framework.AttrValue; | ||||
| import org.tensorflow.framework.GraphDef; | ||||
| import org.tensorflow.framework.NodeDef; | ||||
| 
 | ||||
| import java.util.HashMap; | ||||
| import java.util.LinkedHashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| 
 | ||||
| /** | ||||
|  * @author raver119@gmail.com | ||||
|  */ | ||||
| public class Histogram extends BaseTransformFloatOp { | ||||
|     public Histogram(SameDiff sameDiff, SDVariable i_v, boolean inPlace, int numBins) { | ||||
|         super(sameDiff, i_v, inPlace); | ||||
|         this.numBins = numBins; | ||||
|     } | ||||
| 
 | ||||
|     public Histogram(SameDiff sameDiff, SDVariable i_v, int[] shape, boolean inPlace, Object[] extraArgs, int numBins) { | ||||
|         super(sameDiff, i_v, shape, inPlace, extraArgs); | ||||
|         this.numBins = numBins; | ||||
|     } | ||||
| 
 | ||||
|     public Histogram(SameDiff sameDiff, SDVariable i_v, Object[] extraArgs, int numBins) { | ||||
|         super(sameDiff, i_v, extraArgs); | ||||
|         this.numBins = numBins; | ||||
|     } | ||||
| 
 | ||||
|     private int numBins = 0; | ||||
| 
 | ||||
|     public Histogram() { | ||||
|         //no-op | ||||
|     } | ||||
| 
 | ||||
|     public Histogram(INDArray x, INDArray z) { | ||||
|         setX(x); | ||||
|         setZ(z); | ||||
| 
 | ||||
|         //FIXME: int cast | ||||
|         numBins = (int) z.length(); | ||||
| 
 | ||||
|         double max = x.maxNumber().doubleValue(); | ||||
|         double min = x.minNumber().doubleValue(); | ||||
| 
 | ||||
|         this.extraArgs = new Object[] {(double) numBins, min, max}; | ||||
|     } | ||||
| 
 | ||||
|     public Histogram(INDArray x, int numberOfBins) { | ||||
|         this(x, Nd4j.create(x.dataType(), numberOfBins)); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Map<String, Object> propertiesForFunction() { | ||||
|         Map<String,Object> ret = new LinkedHashMap<>(); | ||||
|         ret.put("numBins",numBins); | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int opNum() { | ||||
|         return 0; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String opName() { | ||||
|         return "histogram"; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public String onnxName() { | ||||
|         throw new NoOpNameFoundException("No onnx op opName found for " +  opName()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String tensorflowName() { | ||||
|         throw new NoOpNameFoundException("No tensorflow op opName found for " +  opName()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public List<SDVariable> doDiff(List<SDVariable> f1) { | ||||
|         throw new UnsupportedOperationException("Not supported"); | ||||
|     } | ||||
| } | ||||
| @ -1,76 +0,0 @@ | ||||
| /******************************************************************************* | ||||
|  * Copyright (c) 2015-2018 Skymind, Inc. | ||||
|  * | ||||
|  * This program and the accompanying materials are made available under the | ||||
|  * terms of the Apache License, Version 2.0 which is available at | ||||
|  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  * | ||||
|  * Unless required by applicable law or agreed to in writing, software | ||||
|  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  * License for the specific language governing permissions and limitations | ||||
|  * under the License. | ||||
|  * | ||||
|  * SPDX-License-Identifier: Apache-2.0 | ||||
|  ******************************************************************************/ | ||||
| 
 | ||||
| package org.nd4j.linalg.api.ops.impl.transforms.strict; | ||||
| 
 | ||||
| import org.nd4j.autodiff.samediff.SDVariable; | ||||
| import org.nd4j.autodiff.samediff.SameDiff; | ||||
| import org.nd4j.imports.NoOpNameFoundException; | ||||
| import org.nd4j.linalg.api.buffer.DataBuffer; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.CustomOp; | ||||
| import org.nd4j.linalg.api.ops.Op; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; | ||||
| 
 | ||||
| import java.nio.Buffer; | ||||
| 
 | ||||
| /** | ||||
|  * Softmax derivative | ||||
|  * | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class SoftMaxDerivative extends SoftMax { | ||||
|     public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2) { | ||||
|         super(sameDiff, new SDVariable[]{i_v1, i_v2}); | ||||
|     } | ||||
| 
 | ||||
|     public SoftMaxDerivative(SameDiff sameDiff, SDVariable i_v1, SDVariable i_v2, boolean inPlace) { | ||||
|         super(sameDiff, new SDVariable[]{ i_v1, i_v2}, inPlace); | ||||
|     } | ||||
| 
 | ||||
|     public SoftMaxDerivative(INDArray x, INDArray z) { | ||||
|         super(x, z); | ||||
|     } | ||||
| 
 | ||||
|     public SoftMaxDerivative(INDArray x) { | ||||
|         super(x, x); | ||||
|     } | ||||
| 
 | ||||
|     public SoftMaxDerivative() {} | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public int opNum() { | ||||
|         return 1; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String onnxName() { | ||||
|         throw new NoOpNameFoundException("No onnx op opName found for " +  opName()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String tensorflowName() { | ||||
|         throw new NoOpNameFoundException("No tensorflow op opName found for " +  opName()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String opName() { | ||||
|         return "_softmaxderivative"; | ||||
|     } | ||||
| } | ||||
| @ -1529,85 +1529,40 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda | ||||
|         Nd4j.getExecutioner().commit(); | ||||
| 
 | ||||
| 
 | ||||
|             AllocationPoint old = allocationPoint; | ||||
|             allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false); | ||||
|         AllocationPoint old = allocationPoint; | ||||
|         allocationPoint = AtomicAllocator.getInstance().allocateMemory(this, new AllocationShape(length, elementSize, dataType()), false); | ||||
| 
 | ||||
|             Nd4j.getDeallocatorService().pickObject(this); | ||||
|             trackingPoint = allocationPoint.getObjectId(); | ||||
|         Nd4j.getDeallocatorService().pickObject(this); | ||||
|         trackingPoint = allocationPoint.getObjectId(); | ||||
|         val oldLength = this.length; | ||||
|         this.length = length; | ||||
| 
 | ||||
|             switch(dataType()){ | ||||
|                 case DOUBLE: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asDoublePointer(); | ||||
|                     indexer = DoubleIndexer.create((DoublePointer) pointer); | ||||
|                     break; | ||||
|                 case FLOAT: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asFloatPointer(); | ||||
|                     indexer = FloatIndexer.create((FloatPointer) pointer); | ||||
|                     break; | ||||
|                 case BFLOAT16: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); | ||||
|                     indexer = Bfloat16Indexer.create((ShortPointer) pointer); | ||||
|                     break; | ||||
|                 case HALF: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); | ||||
|                     indexer = ShortIndexer.create((ShortPointer) pointer); | ||||
|                     break; | ||||
|                 case LONG: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer(); | ||||
|                     indexer = LongIndexer.create((LongPointer) pointer); | ||||
|                     break; | ||||
|                 case UINT64: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asLongPointer(); | ||||
|                     indexer = LongIndexer.create((LongPointer) pointer); | ||||
|                     break; | ||||
|                 case INT: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer(); | ||||
|                     indexer = IntIndexer.create((IntPointer) pointer); | ||||
|                     break; | ||||
|                 case UINT32: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asIntPointer(); | ||||
|                     indexer = IntIndexer.create((IntPointer) pointer); | ||||
|                     break; | ||||
|                 case SHORT: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); | ||||
|                     indexer = ShortIndexer.create((ShortPointer) pointer); | ||||
|                     break; | ||||
|                 case UINT16: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asShortPointer(); | ||||
|                     indexer = UShortIndexer.create((ShortPointer) pointer); | ||||
|                     break; | ||||
|                 case BYTE: | ||||
|                     this.pointer = new CudaPointer(allocationPoint.getPointers().getHostPointer(), length, 0).asBytePointer(); | ||||
|                     indexer = ByteIndexer.create((BytePointer) pointer); | ||||
|                     break; | ||||
|                 default: | ||||
|                     throw new UnsupportedOperationException(); | ||||
|             } | ||||
|         // if the original buffer had a host pointer allocated, we'll reallocate the host buffer as well | ||||
|         if (old.getHostPointer() != null) { | ||||
|             lazyAllocateHostPointer(); | ||||
|         } | ||||
| 
 | ||||
|             CudaContext context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); | ||||
|             NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream()); | ||||
|         val context = (CudaContext) AtomicAllocator.getInstance().getDeviceContext().getContext(); | ||||
|         NativeOpsHolder.getInstance().getDeviceNativeOps().memsetAsync(allocationPoint.getDevicePointer(), 0, length * elementSize, 0, context.getSpecialStream()); | ||||
| 
 | ||||
|             MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; | ||||
|             val perfD = PerformanceTracker.getInstance().helperStartTransaction(); | ||||
|         MemcpyDirection direction = MemcpyDirection.DEVICE_TO_DEVICE; | ||||
|         val perfD = PerformanceTracker.getInstance().helperStartTransaction(); | ||||
| 
 | ||||
|             if (old.isActualOnDeviceSide()) { | ||||
|                 NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getDevicePointer(), this.length * elementSize, CudaConstants.cudaMemcpyDeviceToDevice, context.getSpecialStream()); | ||||
|             } else if (old.isActualOnHostSide()) { | ||||
|                 NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getHostPointer(), this.length * elementSize, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); | ||||
|                 direction = MemcpyDirection.HOST_TO_DEVICE; | ||||
|             } | ||||
|         if (old.isActualOnDeviceSide()) { | ||||
|             NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getDevicePointer(), oldLength * elementSize, CudaConstants.cudaMemcpyDeviceToDevice, context.getSpecialStream()); | ||||
|         } else if (old.isActualOnHostSide()) { | ||||
|             NativeOpsHolder.getInstance().getDeviceNativeOps().memcpyAsync(allocationPoint.getDevicePointer(), old.getHostPointer(), oldLength * elementSize, CudaConstants.cudaMemcpyHostToDevice, context.getSpecialStream()); | ||||
|             direction = MemcpyDirection.HOST_TO_DEVICE; | ||||
|         } | ||||
| 
 | ||||
|             context.getSpecialStream().synchronize(); | ||||
|         context.getSpecialStream().synchronize(); | ||||
| 
 | ||||
|             PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD, allocationPoint.getNumberOfBytes(), direction); | ||||
|         PerformanceTracker.getInstance().helperRegisterTransaction(allocationPoint.getDeviceId(), perfD, allocationPoint.getNumberOfBytes(), direction); | ||||
| 
 | ||||
|             allocationPoint.tickDeviceWrite(); | ||||
|             // we're keeping pointer reference for JVM | ||||
|             pointer.address(); | ||||
|         allocationPoint.tickDeviceWrite(); | ||||
| 
 | ||||
| 
 | ||||
|             // we need to update length with new value now | ||||
|             //this.length = length; | ||||
|         // length was already updated above, right after the new allocation | ||||
|         if(isAttached()){ | ||||
|             // do nothing here, that's workspaces | ||||
|         } else{ | ||||
| @ -1619,7 +1574,10 @@ public abstract class BaseCudaDataBuffer extends BaseDataBuffer implements JCuda | ||||
| 
 | ||||
|     @Override | ||||
|     public long capacity() { | ||||
|         return pointer.capacity(); | ||||
|         if (allocationPoint.getHostPointer() != null) | ||||
|             return pointer.capacity(); | ||||
|         else | ||||
|             return length; | ||||
|     } | ||||
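
The capacity() change above is the companion to lazy host allocation: when no host pointer was ever materialized, `pointer` cannot be asked for its capacity, so the logical `length` is reported instead. A sketch of the intended contract, mirroring the reallocation tests further below (illustrative, not new test code):

    // capacity() should report the element count whether or not a host pointer exists.
    DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4});
    assertEquals(4, buffer.capacity());
    buffer.reallocate(6);
    assertEquals(6, buffer.capacity());
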
| 
 | ||||
|     @Override | ||||
|  | ||||
| @ -16755,6 +16755,27 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); | ||||
|                                                                                 } | ||||
| //         #endif | ||||
| 
 | ||||
|         /** | ||||
|          * This operation calculates the number of entries per bin | ||||
|          */ | ||||
| //         #if NOT_EXCLUDED(OP_histogram) | ||||
|         @Namespace("nd4j::ops") public static class histogram extends DeclarableCustomOp { | ||||
|             static { Loader.load(); } | ||||
|             /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||||
|             public histogram(Pointer p) { super(p); } | ||||
|             /** Native array allocator. Access with {@link Pointer#position(long)}. */ | ||||
|             public histogram(long size) { super((Pointer)null); allocateArray(size); } | ||||
|             private native void allocateArray(long size); | ||||
|             @Override public histogram position(long position) { | ||||
|                 return (histogram)super.position(position); | ||||
|             } | ||||
|          | ||||
|                                                                                     public histogram() { super((Pointer)null); allocate(); } | ||||
|                                                                                     private native void allocate(); | ||||
|                                                                                     public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); | ||||
|                                                                                 } | ||||
| //         #endif | ||||
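
The generated binding above only declares the op. For intuition, its documented semantics ("number of entries per bin") amount to equal-width binning over [min, max]; a host-side reference sketch of that rule follows (an illustrative assumption, not the native kernel):

    // Equal-width histogram over [min, max]; the max value is clamped into the last bin.
    static long[] histogramSketch(double[] x, int numBins) {
        double min = x[0], max = x[0];
        for (double v : x) { min = Math.min(min, v); max = Math.max(max, v); }
        long[] bins = new long[numBins];
        double width = (max - min) / numBins;
        for (double v : x) {
            int idx = (width == 0.0) ? 0 : (int) ((v - min) / width);
            bins[Math.min(idx, numBins - 1)]++;
        }
        return bins;
    }
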
|      | ||||
| 
 | ||||
| 
 | ||||
| // #endif | ||||
|  | ||||
| @ -79,7 +79,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.Axpy; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.same.OldReverse; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.strict.ACosh; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh; | ||||
| import org.nd4j.linalg.api.shape.Shape; | ||||
| import org.nd4j.linalg.checkutil.NDArrayCreationUtil; | ||||
| @ -931,13 +930,6 @@ public class Nd4jTestsC extends BaseNd4jTest { | ||||
|         System.out.println("Second: OK"); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSoftmaxDerivative() { | ||||
|         INDArray input = Nd4j.create(new double[] {-1.07, -0.01, 0.45, 0.95, 0.45, 0.16, 0.20, 0.80, 0.89, 0.25}).reshape(1, -1).transpose(); | ||||
|         INDArray output = Nd4j.create(10, 1); | ||||
|         Nd4j.getExecutioner().exec(new SoftMaxDerivative(input, output)); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void testVStackDifferentOrders() { | ||||
| @ -7245,19 +7237,6 @@ public class Nd4jTestsC extends BaseNd4jTest { | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void testInvalidTransformsSoftmax(){ | ||||
|         INDArray arr = Nd4j.zeros(2,3,4); | ||||
|         try{ | ||||
|             Transforms.softmax(arr); | ||||
|             fail("Expected exception"); | ||||
|         } catch (IllegalArgumentException e){ | ||||
|             //OK | ||||
|             assertTrue(e.getMessage().contains("rank 2")); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testEmptyCasting(){ | ||||
|         for(val from : DataType.values()) { | ||||
|  | ||||
| @ -17,6 +17,7 @@ | ||||
| package org.nd4j.linalg.api.buffer; | ||||
| 
 | ||||
| 
 | ||||
| import lombok.val; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.BaseNd4jTest; | ||||
| import org.nd4j.linalg.api.memory.MemoryWorkspace; | ||||
| @ -28,6 +29,7 @@ import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.factory.Nd4jBackend; | ||||
| 
 | ||||
| import java.io.*; | ||||
| import java.util.Arrays; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.junit.Assert.assertTrue; | ||||
| @ -85,10 +87,12 @@ public class IntDataBufferTests extends BaseNd4jTest { | ||||
|     public void testReallocation() { | ||||
|         DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4}); | ||||
|         assertEquals(4, buffer.capacity()); | ||||
|         int[] old = buffer.asInt(); | ||||
|         buffer.reallocate(6); | ||||
|         val old = buffer.asInt(); | ||||
|         assertEquals(6, buffer.capacity()); | ||||
|         assertArrayEquals(old, buffer.asInt()); | ||||
|         val newContent = buffer.asInt(); | ||||
|         assertEquals(6, newContent.length); | ||||
|         assertArrayEquals(old, Arrays.copyOf(newContent, old.length)); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
| @ -98,12 +102,14 @@ public class IntDataBufferTests extends BaseNd4jTest { | ||||
|         MemoryWorkspace workspace = Nd4j.getWorkspaceManager().getAndActivateWorkspace(initialConfig, "SOME_ID"); | ||||
| 
 | ||||
|         DataBuffer buffer = Nd4j.createBuffer(new int[] {1, 2, 3, 4}); | ||||
|         int[] old = buffer.asInt(); | ||||
|         val old = buffer.asInt(); | ||||
|         assertTrue(buffer.isAttached()); | ||||
|         assertEquals(4, buffer.capacity()); | ||||
|         buffer.reallocate(6); | ||||
|         assertEquals(6, buffer.capacity()); | ||||
|         assertArrayEquals(old, buffer.asInt()); | ||||
|         val newContent = buffer.asInt(); | ||||
|         assertEquals(6, newContent.length); | ||||
|         assertArrayEquals(old, Arrays.copyOf(newContent, old.length)); | ||||
|         workspace.close(); | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -82,12 +82,12 @@ public class RngTests extends BaseNd4jTest { | ||||
|         INDArray narr = Nd4j.randn('c', rows, cols); | ||||
|         assertArrayEquals(new long[] {rows, cols}, narr.shape()); | ||||
|         assertEquals('c', narr.ordering()); | ||||
|         assertEquals(narr.meanNumber().doubleValue(), 0.0, 0.05); | ||||
|         assertEquals(0.0, narr.meanNumber().doubleValue(), 0.05); | ||||
| 
 | ||||
|         INDArray narr2 = Nd4j.randn('f', rows, cols); | ||||
|         assertArrayEquals(new long[] {rows, cols}, narr2.shape()); | ||||
|         assertEquals('f', narr2.ordering()); | ||||
|         assertEquals(narr2.meanNumber().doubleValue(), 0.0, 0.05); | ||||
|         assertEquals(0.0, narr2.meanNumber().doubleValue(), 0.05); | ||||
| 
 | ||||
|         INDArray narr3 = Nd4j.randn('c', new int[] {rows, cols, dim2}); | ||||
|         assertArrayEquals(new long[] {rows, cols, dim2}, narr3.shape()); | ||||
| @ -97,7 +97,7 @@ public class RngTests extends BaseNd4jTest { | ||||
|         INDArray narr4 = Nd4j.randn('f', new int[] {rows, cols, dim2}); | ||||
|         assertArrayEquals(new long[] {rows, cols, dim2}, narr4.shape()); | ||||
|         assertEquals('f', narr4.ordering()); | ||||
|         assertEquals(narr4.meanNumber().doubleValue(), 0.0, 0.05); | ||||
|         assertEquals(0.0, narr4.meanNumber().doubleValue(), 0.05); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|  | ||||
| @ -29,7 +29,6 @@ import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftMaxDerivative; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.factory.Nd4jBackend; | ||||
| import org.nd4j.linalg.indexing.BooleanIndexing; | ||||
| @ -159,7 +158,6 @@ public class CrashTest extends BaseNd4jTest { | ||||
| 
 | ||||
|         // logisoftmax, softmax & softmax derivative | ||||
|         Nd4j.getExecutioner().exec((CustomOp) new SoftMax(x)); | ||||
|         Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(x)); | ||||
|         Nd4j.getExecutioner().exec((CustomOp) new LogSoftMax(x)); | ||||
| 
 | ||||
| 
 | ||||
|  | ||||
| @ -441,8 +441,8 @@ public class BooleanIndexingTest extends BaseNd4jTest { | ||||
|         Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ANY_PANIC); | ||||
|         NativeOpsHolder.getInstance().getDeviceNativeOps().enableDebugMode(true); | ||||
|         INDArray arr = Nd4j.linspace(1,4,4, Nd4j.dataType()).reshape(2,2); | ||||
|         INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{arr},Arrays.asList(2.0), Collections.emptyList(),new GreaterThan()); | ||||
|         assertEquals(4,filtered.length()); | ||||
|         INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{arr}, Arrays.asList(2.0), Collections.emptyList(),new GreaterThan()); | ||||
|         assertEquals(2, filtered.length()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @ -450,7 +450,7 @@ public class BooleanIndexingTest extends BaseNd4jTest { | ||||
|     public void testChooseGreaterThanZero() { | ||||
|         INDArray zero = Nd4j.linspace(0,4,4, Nd4j.dataType()); | ||||
|         INDArray filtered = BooleanIndexing.chooseFrom(new INDArray[]{zero},Arrays.asList(0.0), Collections.emptyList(),new GreaterThan()); | ||||
|         assertEquals(3,filtered.length()); | ||||
|         assertEquals(3, filtered.length()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|  | ||||
| @ -331,7 +331,7 @@ public class MixedDataTypesTests extends BaseNd4jTest { | ||||
|         assertArrayEquals(exp, arr); | ||||
|     } | ||||
| 
 | ||||
|     @Test(expected = IllegalArgumentException.class) | ||||
|     @Test(expected = RuntimeException.class) | ||||
|     public void testTypesValidation_3() { | ||||
|         val arrayX = Nd4j.create(new int[]{1, 2, 3, 4}, new  long[]{4}, DataType.INT); | ||||
| 
 | ||||
|  | ||||
| @ -195,47 +195,6 @@ public class DerivativeTests extends BaseNd4jTest { | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSoftMaxDerivative() { | ||||
|           Random r = new Random(12345L); | ||||
| 
 | ||||
|         int[] mb = new int[] {10, 2, 1}; | ||||
|         for (int minibatch : mb) { | ||||
|             System.out.println("Minibatch size: " + minibatch); | ||||
|             INDArray z = Nd4j.zeros(minibatch, 5); | ||||
|             double[][] in = new double[minibatch][5]; | ||||
|             double[][] softmax = new double[minibatch][5]; | ||||
|             double[][] expOut = new double[minibatch][5]; | ||||
|             for (int i = 0; i < minibatch; i++) { | ||||
|                 double rowSumExp = 0.0; | ||||
|                 for (int j = 0; j < 5; j++) { | ||||
|                     in[i][j] = 10 * r.nextDouble(); | ||||
|                     z.putScalar(new int[] {i, j}, in[i][j]); | ||||
|                     rowSumExp += FastMath.exp(in[i][j]); | ||||
|                 } | ||||
|                 for (int j = 0; j < 5; j++) { | ||||
|                     softmax[i][j] = FastMath.exp(in[i][j]) / rowSumExp; | ||||
|                     expOut[i][j] = softmax[i][j] * (1.0 - softmax[i][j]); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             INDArray sm = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(z.dup()))[0]; | ||||
|             INDArray zPrime = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(z))[0]; | ||||
|             System.out.println(Arrays.toString(sm.data().asDouble())); | ||||
|             System.out.println(Arrays.toString(zPrime.data().asDouble())); | ||||
|             assertNotEquals(sm, zPrime); | ||||
| 
 | ||||
|             for (int i = 0; i < minibatch; i++) { | ||||
|                 for (int j = 0; j < 5; j++) { | ||||
|                     double relError = Math.abs(expOut[i][j] - zPrime.getDouble(i, j)) | ||||
|                                     / (Math.abs(expOut[i][j]) + Math.abs(zPrime.getDouble(i, j))); | ||||
|                     //                    System.out.println("Error: " + relError); | ||||
|                     assertTrue(relError < REL_ERROR_TOLERANCE); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSoftPlusDerivative() { | ||||
| @ -384,180 +343,6 @@ public class DerivativeTests extends BaseNd4jTest { | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void softmaxsimpleLossTest() { | ||||
|         /* | ||||
|             Softmax derivative is correct if it is standalone | ||||
|             But when we are applying it in the chain rule the current derivative function is incomplete. | ||||
|             For this test, I am assuming that the function of interest is just MSE | ||||
|             What the fix is: | ||||
|             The derivative of a softmax needs to return a rank 2 matrix. | ||||
|             Right now we get only the diagonal elements of this matrix | ||||
|             http://stats.stackexchange.com/questions/79454/softmax-layer-in-a-neural-network | ||||
|          */ | ||||
|         //random array representing preout | ||||
|         INDArray X = Nd4j.rand(1, 2); | ||||
|         //preout transformed to y_hat with softmax | ||||
|         INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0]; | ||||
| 
 | ||||
|         //hard coding something to construct a function with, using MSE | ||||
|         INDArray Y = Nd4j.create(new double[][] {{0.123, 1 - 0.123}}); | ||||
| 
 | ||||
|         //This is the MSE now | ||||
|         double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue(); | ||||
| 
 | ||||
|         INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0]; | ||||
| 
 | ||||
|         //the way we apply the chain rule now is 2*(y-yhat)*softmaxder | ||||
|         INDArray dLdY = Y.sub(YHat).mul(-2); | ||||
|         INDArray currentGradient = dLdY.mul(softmaxDer); | ||||
| 
 | ||||
|         //what I think we should be doing | ||||
|         // we have x0, x1 -> y0,y1 | ||||
|         //we need the derivatives of the output of the softmax wrt every input (x0,x1) | ||||
|         // we only have dy0/dx0 and dy1/dx1 | ||||
|         // we also need dy0/dx1 and dy1/dx0 | ||||
|         // the below is the chain rule in calc applied when L is a function of y0,y1; y0 and y1 are in turn functions of BOTH (x0 and x1) | ||||
|         // dL/dx0 = (dl/dy0) * (dy0/dx0) + (dL/dy1) * (dy1/dx0) | ||||
|         // dL/dx1 = (dl/dy0) * (dy0/dx1) + (dL/dy1) * (dy1/dx1) | ||||
|         // worked it out on paper and googled it (should have googled first, gave formula from link above) | ||||
|         // dy0/dx0 = y0*(1-y0) = y0*y1 | ||||
|         // dy1/dx0 = -y1*(1-y1) = -y0*y1 | ||||
|         // dy0/dx1 = -y0*(1-y0) = -y0*y1 | ||||
|         // dy1/dx1 = y1*(1-y1) = y0*y1 | ||||
|         //[ dL/dy0 dL/dy1] [[dy0/dx0 dy1/dx0] [dy0/dx1 dy1/dx1]] | ||||
|         double y0y1 = softmaxDer.getDouble(0, 0); | ||||
|         //hack but this is what we need to implement, straightforward here but complicated for >2 | ||||
|         //INDArray mysoftmaxDer = Nd4j.create(new double[][] {{y0y1,y0y1*-1},{-1*y0y1,y0y1}}); | ||||
|         INDArray mysoftmaxDer = correctSoftmax(X); | ||||
|         INDArray myGradient = mysoftmaxDer.mulRowVector(dLdY).sum(1); | ||||
| 
 | ||||
|         double epsilon = 0.0001; | ||||
|         INDArray Xiplus, Ximinus; | ||||
|         INDArray YHatplus, YHatminus; | ||||
|         double lossplus, lossminus; | ||||
| 
 | ||||
|         INDArray numGradient = Nd4j.zeros(1, 2); | ||||
| 
 | ||||
|         for (int i = 0; i < 2; i++) { | ||||
|             /* change X one value one at a time */ | ||||
| 
 | ||||
|             // +epsilon | ||||
|             double x = X.getDouble(0, i); | ||||
|             Xiplus = X.dup(); | ||||
|             Xiplus.put(0, i, x + epsilon); | ||||
|             YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0]; | ||||
|             lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue(); | ||||
| 
 | ||||
|             // -epsilon | ||||
|             Ximinus = X.dup(); | ||||
|             Ximinus.put(0, i, x - epsilon); | ||||
|             YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0]; | ||||
|             lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue(); | ||||
| 
 | ||||
|             double gradienti = (lossplus - lossminus) / (2 * epsilon); | ||||
|             numGradient.put(0, i, gradienti); | ||||
|         } | ||||
|         System.out.println("========================="); | ||||
|         System.out.println("NUMERICAL:"); | ||||
|         System.out.println(numGradient); | ||||
|         System.out.println("\nCURRENTLY:"); | ||||
|         System.out.println(currentGradient); | ||||
|         System.out.println("\nMY GRADIENT:"); | ||||
|         System.out.println(myGradient + "\n"); | ||||
|         System.out.println( | ||||
|                         "Because of the nature of the derivative of the softmax for length = 2, our current method will make it off by a factor of 2"); | ||||
|         System.out.println("========================="); | ||||
|     } | ||||
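
In compact form, the Jacobian the comments above work out by hand is the standard softmax derivative (stated here for reference; \delta_{ij} is the Kronecker delta):

    \frac{\partial y_i}{\partial x_j} = y_i(\delta_{ij} - y_j), \qquad
    \frac{\partial L}{\partial x_j} = \sum_i \frac{\partial L}{\partial y_i} \, y_i(\delta_{ij} - y_j)

For length 2 this gives diagonal entries y_0 y_1 and off-diagonal entries -y_0 y_1, which is why, for the MSE loss used here (where dL/dy1 = -dL/dy0), the diagonal-only method is off by exactly a factor of 2.
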
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void softmaxsimplelongerlengthLossTest() { | ||||
|         /* | ||||
|             Read comments in earlier test for length = 2 | ||||
|          */ | ||||
|         //random array representing preout | ||||
|         int someLength = 7; | ||||
| 
 | ||||
|         INDArray X = Nd4j.rand(1, someLength); | ||||
|         //preout transformed to y_hat with softmax | ||||
|         INDArray YHat = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0]; | ||||
| 
 | ||||
|         //hard coding something to construct a function with, using MSE | ||||
|         INDArray temp = Nd4j.rand(1, someLength); | ||||
|         INDArray Y = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(temp))[0]; | ||||
| 
 | ||||
|         //This is the MSE now | ||||
|         double lossHere = Transforms.pow(Y.sub(YHat), 2).sumNumber().doubleValue(); | ||||
| 
 | ||||
|         INDArray softmaxDer = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0]; | ||||
| 
 | ||||
|         //the way we apply the chain rule now is 2*(y-yhat)*softmaxder | ||||
|         INDArray dLdY = Y.sub(YHat).mul(-2); | ||||
|         INDArray currentGradient = dLdY.mul(softmaxDer); | ||||
| 
 | ||||
|         INDArray mysoftmaxDer = correctSoftmax(X); | ||||
|         INDArray myGradient = mysoftmaxDer.mulRowVector(dLdY).sum(1); | ||||
| 
 | ||||
|         double epsilon = 0.0001; | ||||
|         INDArray Xiplus, Ximinus; | ||||
|         INDArray YHatplus, YHatminus; | ||||
|         double lossplus, lossminus; | ||||
| 
 | ||||
|         INDArray numGradient = Nd4j.zeros(1, someLength); | ||||
| 
 | ||||
|         for (int i = 0; i < someLength; i++) { | ||||
|             /* change X one value one at a time */ | ||||
| 
 | ||||
|             // +epsilon | ||||
|             double x = X.getDouble(0, i); | ||||
|             Xiplus = X.dup(); | ||||
|             Xiplus.put(0, i, x + epsilon); | ||||
|             YHatplus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Xiplus.dup()))[0]; | ||||
|             lossplus = Transforms.pow(Y.sub(YHatplus), 2).sumNumber().doubleValue(); | ||||
| 
 | ||||
|             // -epsilon | ||||
|             Ximinus = X.dup(); | ||||
|             Ximinus.put(0, i, x - epsilon); | ||||
|             YHatminus = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(Ximinus.dup()))[0]; | ||||
|             lossminus = Transforms.pow(Y.sub(YHatminus), 2).sumNumber().doubleValue(); | ||||
| 
 | ||||
|             double gradienti = (lossplus - lossminus) / (2 * epsilon); | ||||
|             numGradient.put(0, i, gradienti); | ||||
|         } | ||||
|         System.out.println("========================="); | ||||
|         System.out.println("NUMERICAL GRADIENT:"); | ||||
|         System.out.println(new NDArrayStrings(6).format(numGradient).toString()); | ||||
|         System.out.println("\nANALYTIC USING EXISTING SOFTMAX DER:"); | ||||
|         System.out.println(new NDArrayStrings(6).format(currentGradient).toString()); | ||||
|         System.out.println("\nGRADIENT USING MY VERSION OF SOFTMAX DER:"); | ||||
|         System.out.println(new NDArrayStrings(6).format(myGradient).toString()); | ||||
|         System.out.println("========================="); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public static INDArray correctSoftmax(INDArray X) { | ||||
|         // this is only for X as a row vector | ||||
|         // should return a rank 2 matrix: diagonal elements are pi*(1-pi), | ||||
|         // the rest are -pi*pj | ||||
|         INDArray p = Nd4j.getExecutioner().exec((CustomOp) new SoftMax(X.dup()))[0]; | ||||
|         INDArray pCol = p.dup().transpose(); | ||||
|         INDArray pipj = pCol.mmul(p); | ||||
|         pipj.muli(-1); | ||||
| 
 | ||||
|         //so now pipj is correct except for the diagonal elements | ||||
|         // which by the way is what our current softmax der gives us | ||||
|         INDArray diagp = Nd4j.getExecutioner().exec((CustomOp) new SoftMaxDerivative(X.dup()))[0]; | ||||
| 
 | ||||
| 
 | ||||
|         //ugly for loop to correct diag elements | ||||
|         for (int i = 0; i < X.length(); i++) { | ||||
|             pipj.put(i, i, diagp.getDouble(0, i)); | ||||
|         } | ||||
| 
 | ||||
|         return pipj; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public char ordering() { | ||||
|         return 'f'; | ||||
|  | ||||
| @ -50,9 +50,10 @@ import org.nd4j.linalg.api.ops.impl.scalar.ScalarReverseSubtraction; | ||||
| import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarGreaterThan; | ||||
| import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarLessThan; | ||||
| import org.nd4j.linalg.api.ops.impl.summarystats.Variance; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.Histogram; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.HistogramFixedWidth; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.floating.Histogram; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.AddOp; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.OldMulOp; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.strict.Exp; | ||||
| @ -876,14 +877,14 @@ public class OpExecutionerTestsC extends BaseNd4jTest { | ||||
|     @Test | ||||
|     public void testHistogram1() { | ||||
|         INDArray x = Nd4j.linspace(1, 1000, 100000, DataType.DOUBLE); | ||||
|         INDArray z = Nd4j.zeros(DataType.DOUBLE,new long[]{20}); | ||||
|         INDArray z = Nd4j.zeros(DataType.LONG,new long[]{20}); | ||||
| 
 | ||||
|         INDArray xDup = x.dup(); | ||||
|         INDArray zDup = z.dup(); | ||||
| 
 | ||||
|         INDArray zExp = Nd4j.create(DataType.DOUBLE, 20).assign(5000); | ||||
|         INDArray zExp = Nd4j.create(DataType.LONG, 20).assign(5000); | ||||
| 
 | ||||
|         Histogram histogram = new Histogram(x, z); | ||||
|         val histogram = new Histogram(x, z); | ||||
| 
 | ||||
|         Nd4j.getExecutioner().exec(histogram); | ||||
| 
 | ||||
| @ -899,16 +900,13 @@ public class OpExecutionerTestsC extends BaseNd4jTest { | ||||
|     public void testHistogram2() { | ||||
|         INDArray x = Nd4j.create(new float[] {0f, 0f, 0f, 5f, 5f, 5f, 10f, 10f, 10f}); | ||||
| 
 | ||||
| 
 | ||||
|         INDArray xDup = x.dup(); | ||||
| 
 | ||||
|         INDArray zExp = Nd4j.zeros(DataType.FLOAT, 10).putScalar(0, 3f).putScalar(5, 3f).putScalar(9, 3f); | ||||
|         INDArray zExp = Nd4j.zeros(DataType.LONG, 10).putScalar(0, 3).putScalar(5, 3).putScalar(9, 3); | ||||
| 
 | ||||
|         Histogram histogram = new Histogram(x, 10); | ||||
|         val histogram = new Histogram(x, 10); | ||||
| 
 | ||||
|         Nd4j.getExecutioner().exec(histogram); | ||||
| 
 | ||||
|         INDArray z = histogram.z(); | ||||
|         val z = Nd4j.getExecutioner().exec(histogram)[0]; | ||||
| 
 | ||||
|         assertEquals(xDup, x); | ||||
| 
 | ||||
|  | ||||
| @ -19,6 +19,7 @@ package org.nd4j.linalg.serde; | ||||
| import lombok.Data; | ||||
| import lombok.EqualsAndHashCode; | ||||
| import lombok.NoArgsConstructor; | ||||
| import lombok.val; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.BaseNd4jTest; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| @ -58,7 +59,7 @@ public class JsonSerdeTests extends BaseNd4jTest { | ||||
|                 Nd4j.getRandom().setSeed(12345); | ||||
|                 INDArray in = Nd4j.rand(DataType.DOUBLE, 3, 4).muli(20).subi(10); | ||||
| 
 | ||||
|                 ObjectMapper om = new ObjectMapper(); | ||||
|                 val om = new ObjectMapper(); | ||||
| 
 | ||||
|                 for (DataType dt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF, DataType.LONG, DataType.INT, DataType.SHORT, | ||||
|                         DataType.BYTE, DataType.UBYTE, DataType.BOOL, DataType.UTF8}) { | ||||
|  | ||||
| @ -673,7 +673,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { | ||||
| 
 | ||||
|         workspace.initializeWorkspace(); | ||||
|         long reqMemory = 12 * Nd4j.sizeOfDataType(arrayCold.dataType()); | ||||
|         assertEquals(reqMemory + reqMemory % 8 + Nd4j.sizeOfDataType(DOUBLE), workspace.getCurrentSize()); | ||||
|         assertEquals(reqMemory + reqMemory % 8, workspace.getCurrentSize()); | ||||
| 
 | ||||
| 
 | ||||
|         log.info("-----------------------"); | ||||
| @ -692,7 +692,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { | ||||
| 
 | ||||
|             array.addi(1.0); | ||||
| 
 | ||||
|             assertEquals(reqMem + reqMem % 8 + Nd4j.sizeOfDataType(DOUBLE), workspace.getPrimaryOffset()); | ||||
|             assertEquals(reqMem + reqMem % 8, workspace.getPrimaryOffset()); | ||||
| 
 | ||||
|             assertEquals("Failed on iteration " + x, 10, array.sumNumber().doubleValue(), 0.01); | ||||
| 
 | ||||
| @ -746,7 +746,7 @@ public class BasicWorkspaceTests extends BaseNd4jTest { | ||||
| 
 | ||||
|         INDArray dup = array.dup(); | ||||
| 
 | ||||
|         assertEquals((reqMemory + reqMemory % 8) * 2 + Nd4j.sizeOfDataType(DOUBLE), workspace.getPrimaryOffset()); | ||||
|         assertEquals((reqMemory + reqMemory % 8) * 2, workspace.getPrimaryOffset()); | ||||
| 
 | ||||
|         assertEquals(5, dup.sumNumber().doubleValue(), 0.01); | ||||
| 
 | ||||
|  | ||||
| @ -91,7 +91,7 @@ public class DebugModeTests extends BaseNd4jTest { | ||||
|             assertEquals(0, ws.getDeviceOffset()); | ||||
| 
 | ||||
|             // array buffer should be spilled now | ||||
|             assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE) + Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize()); | ||||
|             assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| @ -118,7 +118,7 @@ public class DebugModeTests extends BaseNd4jTest { | ||||
|             assertEquals(0, ws.getDeviceOffset()); | ||||
| 
 | ||||
|             // array buffer should be spilled now | ||||
|             assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE) + Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize()); | ||||
|             assertEquals(10 * 10 * Nd4j.sizeOfDataType(DataType.DOUBLE), ws.getSpilledSize()); | ||||
|         } | ||||
| 
 | ||||
|         try (val ws = (Nd4jWorkspace) Nd4j.getWorkspaceManager().getAndActivateWorkspace(basicConfig, "R_119_1992")) { | ||||
|  | ||||