/******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // @author raver119@gmail.com // #include #include #include #include namespace nd4j { namespace ops { namespace helpers { template static void dropoutSimple(NDArray const* input, NDArray* output, double probValue, int seed) { nd4j::graph::RandomGenerator nodeRng(3019L, seed); int inLen = input->lengthOf(); PRAGMA_OMP_PARALLEL_FOR_IF(inLen > Environment::getInstance()->elementwiseThreshold()) for (Nd4jLong e = 0; e < inLen; ++e) { float val = nodeRng.relativeT(e, T(0.f), T(1.f)); if (val < probValue) output->p(e, input->e(e) / probValue); } } BUILD_SINGLE_TEMPLATE(template void dropoutSimple, (NDArray const* input, NDArray* output, double probValue, int seed), FLOAT_TYPES); template int dropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { //NativeOps native; //nd4j::graph::RandomGenerator nodeRng(seed); //static int dropOutFunctor_(nd4j::random::RandomBuffer* rng, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { //NativeOps native; //native.reSeedBuffer(nullptr, (long)seed, rng); //if (newRng ) if (reduceShape == nullptr){ dropoutSimple(input, output, probValue, seed); } else { REQUIRE_TRUE(reduceShape->lengthOf() <= input->rankOf(), 0, "dropout: Noise shape should be fittable to input"); std::vector dims(reduceShape->lengthOf()); bool fit = true; for( int i = 0; fit && (i < dims.size()); i++ ) { dims[i] = reduceShape->e(i); for (int e = 0; fit && (e < input->rankOf()); ++e) if (input->sizeAt(e) % dims[i]) { fit = false; } } // check dims to fit input REQUIRE_TRUE(fit, 0, "dropout: Noise shape should fit to input rank."); std::unique_ptr chunk(new NDArray('c', dims, output->dataType(), output->getContext())); chunk->assign(1.f); //chunk->applyRandom>(rng, nullptr, chunk.get(), &probValue); //NativeOpExecutioner::execRandom(random::DropOutInverted, rng, chunk->buffer(), chunk->shapeInfo(), chunk->buffer(), chunk->shapeInfo(), &prob); dropoutSimple(chunk.get(), chunk.get(), probValue, seed); // broadcast chunk to full matrix std::unique_ptr dropOutMultiplier(new NDArray(*input)); dropOutMultiplier->assign(1.f); *dropOutMultiplier += *chunk; output->assign(*input * *dropOutMultiplier); //input->applyPairwiseTransform(pairwise::Multiply, dropOutMultiplier.get(), output, nullptr); } return Status::OK(); } int dropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue) { auto xType = input->dataType(); BUILD_SINGLE_SELECTOR(xType, return dropOutFunctor_, (context, input, output, reduceShape, seed, probValue), FLOAT_TYPES); } BUILD_SINGLE_TEMPLATE(template int dropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue);, FLOAT_TYPES); /////////////////////////////////// backrpopagations /////////////////////////////////////////////// template static int dropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { int res = dropOutFunctor(context, input, output, reduceShape, seed, probValue); if (ND4J_STATUS_OK == res) for (Nd4jLong e = 0; e < output->lengthOf(); e++) { if (output->e(e) != 0.f) output->p(e, gradOut->e(e) / probValue); // else (*output)(e) = T(0.f); } return res; } template static int alphaDropOutFunctor_(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { //NativeOps native; //auto rng = context.getRNG(); //native.reSeedBuffer(nullptr, (long)seed, rng); //if (rng == nullptr) // return ND4J_STATUS_BAD_RNG; //T probValueArr[] = {probValue, alpha, alpha1, beta}; //input->template applyRandom>(rng, nullptr, output, probValueArr); nd4j::graph::RandomGenerator nodeRng(3019L, seed); PRAGMA_OMP_PARALLEL_FOR_IF(input->lengthOf() > Environment::getInstance()->elementwiseThreshold()) for (Nd4jLong e = 0; e < input->lengthOf(); ++e) { float randVal = nodeRng.relativeT(e, T(0.f), T(1.f)); float xVal = input->e(e); output->p(e, randVal >= probValue ? alpha * beta + alpha1 : alpha * xVal + alpha1); } return ND4J_STATUS_OK; } template int alphaDropOutFunctorBP_(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { int res = alphaDropOutFunctor(context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta); if (res == ND4J_STATUS_OK) { (*output) *= alpha; (*output) *= (*gradOut); //->applyPairwiseTransform(gradOut, output, nullptr); } return res; } int dropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue) { BUILD_SINGLE_SELECTOR(context.dataType(), return dropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue), FLOAT_TYPES); } BUILD_SINGLE_TEMPLATE(template int dropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue), FLOAT_TYPES); int alphaDropOutFunctor(graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctor_, (context, input, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); } BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctor_, (graph::Context& context, NDArray* input, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); int alphaDropOutFunctorBP(graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta) { BUILD_SINGLE_SELECTOR(context.dataType(), return alphaDropOutFunctorBP_, (context, input, gradOut, output, reduceShape, seed, probValue, alpha, alpha1, beta), FLOAT_TYPES); } BUILD_SINGLE_TEMPLATE(template int alphaDropOutFunctorBP_, (graph::Context& context, NDArray* input, NDArray* gradOut, NDArray* output, NDArray* reduceShape, int seed, double probValue, double alpha, double alpha1, double beta), FLOAT_TYPES); } } }