parent
24e43e9856
commit
fb139fcac6
|
@ -807,7 +807,7 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
|
||||||
BUILD_SINGLE_TEMPLATE(template void clipByNorm_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);
|
BUILD_SINGLE_TEMPLATE(template void clipByNorm_, (nd4j::LaunchContext * context, NDArray& input, NDArray& output, const std::vector<int>& dimensions, const NDArray& clipNorm, const bool isInplace), FLOAT_TYPES);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void clipByGlobalNorm_(nd4j::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, nd4j::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
|
void clipByGlobalNorm_(nd4j::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, nd4j::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
|
||||||
NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))
|
NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
PRAGMA_OMP_PARALLEL_FOR
|
||||||
|
|
Loading…
Reference in New Issue