parent
935cbf5df3
commit
72cb5936f1
|
@ -806,7 +806,6 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
|
||||||
void clipByGlobalNorm_(nd4j::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, nd4j::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
|
void clipByGlobalNorm_(nd4j::LaunchContext * context, std::vector<NDArray*> const& inputs, double clipNorm, nd4j::memory::Workspace* workspace, std::vector<NDArray*>& outputs, bool isInplace) {
|
||||||
NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))
|
NDArray globalNorm = NDArrayFactory::create<T>(0, inputs[0]->getContext()); //sqrt(sum([l2norm(t)**2 for t in t_list]))
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
|
||||||
for (auto i = 0; i < inputs.size(); i++) {
|
for (auto i = 0; i < inputs.size(); i++) {
|
||||||
auto input = inputs[i];
|
auto input = inputs[i];
|
||||||
auto l2norm = input->reduceNumber(reduce::Norm2);
|
auto l2norm = input->reduceNumber(reduce::Norm2);
|
||||||
|
@ -818,7 +817,6 @@ void clipByNormBP(nd4j::LaunchContext* context, const NDArray& input, const NDAr
|
||||||
globalNorm.syncToHost();
|
globalNorm.syncToHost();
|
||||||
const T factor = clipNorm / globalNorm.e<T>(0);
|
const T factor = clipNorm / globalNorm.e<T>(0);
|
||||||
|
|
||||||
PRAGMA_OMP_PARALLEL_FOR
|
|
||||||
for (size_t e = 0; e < inputs.size(); e++) {
|
for (size_t e = 0; e < inputs.size(); e++) {
|
||||||
// all-reduce
|
// all-reduce
|
||||||
auto input = inputs[e];
|
auto input = inputs[e];
|
||||||
|
|
Loading…
Reference in New Issue