cavis/libnd4j/include/helpers/helper_random.h

239 lines
6.9 KiB
C
Raw Normal View History

2021-02-01 13:31:45 +01:00
/* ******************************************************************************
*
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* See the NOTICE file distributed with this work for additional
* information regarding copyright ownership.
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author raver119@gmail.com
//
#ifndef LIBND4J_HELPER_RANDOM_H
#define LIBND4J_HELPER_RANDOM_H
#ifdef __CUDACC__
#include <curand.h>
#endif
#include <helpers/helper_generator.h>
#ifndef __CUDACC__
#include <mutex>
#endif
namespace sd {
2019-06-06 14:21:15 +02:00
namespace random {
template<typename T>
class RandomHelper {
private:
sd::random::IGenerator *generator;
sd::random::RandomBuffer *buffer;
2019-06-06 14:21:15 +02:00
public:
_CUDA_HD RandomHelper(sd::random::IGenerator *generator) {
2019-06-06 14:21:15 +02:00
this->generator = generator;
this->buffer = generator->getBuffer();
}
_CUDA_HD RandomHelper(sd::random::RandomBuffer *buffer) {
2019-06-06 14:21:15 +02:00
this->buffer = buffer;
}
/**
* This method returns random int in range [0..MAX_INT]
* @return
*/
inline _CUDA_D int nextInt() {
int r = (int) nextUInt();
return r < 0 ? -1 * r : r;
};
inline _CUDA_D uint64_t nextUInt() {
return buffer->getNextElement();
}
/**
* This method returns random int in range [0..to]
* @param to
* @return
*/
inline _CUDA_D int nextInt(int to) {
int r = nextInt();
int m = to - 1;
if ((to & m) == 0) // i.e., bound is a power of 2
r = (int) ((to * (long) r) >> 31);
else {
for (int u = r;
u - (r = u % to) + m < 0;
u = nextInt());
}
return r;
};
/**
* This method returns random int in range [from..to]
* @param from
* @param to
* @return
*/
inline _CUDA_D int nextInt(int from, int to) {
if (from == 0)
return nextInt(to);
return from + nextInt(to - from);
};
/**
* This method returns random T in range of [0..MAX_FLOAT]
* @return
*/
inline _CUDA_D T nextMaxT() {
T rnd = (T) buffer->getNextElement();
return rnd < 0 ? -1 * rnd : rnd;
};
/**
* This method returns random T in range of [0..1]
* @return
*/
inline _CUDA_D T nextT() {
return (T) nextUInt() / (T) sd::DataTypeUtils::max<Nd4jULong>();
2019-06-06 14:21:15 +02:00
}
/**
* This method returns random T in range of [0..to]
* @param to
* @return
*/
inline _CUDA_D T nextT(T to) {
if (to == (T) 1.0f)
return nextT();
return nextT((T) 0.0f, to);
};
/**
* This method returns random T in range [from..to]
* @param from
* @param to
* @return
*/
inline _CUDA_D T nextT(T from, T to) {
return from + (nextT() * (to - from));
}
inline _CUDA_D uint64_t relativeUInt(Nd4jLong index) {
return buffer->getElement(index);
}
/**
* relative methods are made as workaround for lock-free concurrent execution
*/
inline _CUDA_D int relativeInt(Nd4jLong index) {
return (int) (relativeUInt(index) % (sd::DataTypeUtils::max<uint32_t>() + 1));
2019-06-06 14:21:15 +02:00
}
/**
* This method returns random int within [0..to]
*
* @param index
* @param to
* @return
*/
inline _CUDA_D int relativeInt(Nd4jLong index, int to) {
int rel = relativeInt(index);
return rel % to;
}
/**
* This method returns random int within [from..to]
*
* @param index
* @param to
* @param from
* @return
*/
inline int _CUDA_D relativeInt(Nd4jLong index, int to, int from) {
if (from == 0)
return relativeInt(index, to);
return from + relativeInt(index, to - from);
}
/**
* This method returns random T within [0..1]
*
* @param index
* @return
*/
inline _CUDA_D T relativeT(Nd4jLong index) {
if (sizeof(T) < 4) {
// FIXME: this is fast hack for short types, like fp16. This should be improved.
return (T)((float) relativeUInt(index) / (float) sd::DataTypeUtils::max<uint32_t>());
} else return (T) relativeUInt(index) / (T) sd::DataTypeUtils::max<uint32_t>();
2019-06-06 14:21:15 +02:00
}
/**
* This method returns random T within [0..to]
*
* @param index
* @param to
* @return
*/
inline _CUDA_D T relativeT(Nd4jLong index, T to) {
if (to == (T) 1.0f)
return relativeT(index);
return relativeT(index, (T) 0.0f, to);
}
/**
* This method returns random T within [from..to]
*
* @param index
* @param from
* @param to
* @return
*/
inline _CUDA_D T relativeT(Nd4jLong index, T from, T to) {
return from + (relativeT(index) * (to - from));
}
/**
* This method skips X elements from buffer
*
* @param numberOfElements number of elements to skip
*/
inline _CUDA_D void rewind(Nd4jLong numberOfElements) {
buffer->rewindH(numberOfElements);
}
};
}
}
#endif //LIBND4J_HELPER_RANDOM_H