cavis/libnd4j/include/ops/declarable/helpers/impl/sqrtm.cpp

67 lines
1.8 KiB
C++

/*******************************************************************************
* Copyright (c) 2020 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
//
// @author Yurii Shyrma (iuriish@yahoo.com)
//
#include <ops/declarable/CustomOperations.h>
#include <helpers/Sqrtm.h>
namespace sd {
namespace ops {
namespace helpers {
//////////////////////////////////////////////////////////////////////////
/**
 * Typed worker behind sqrtm(): runs Sqrtm<T>::calc on the input and stores
 * the result in the output.
 *
 * @tparam T element type selected by the caller's type dispatch
 * @param x  input array; either rank 2 or higher
 * @param z  output array that receives the result
 *
 * A rank-2 input is handed to Sqrtm<T>::calc directly. Any other rank is
 * split into sub-arrays along dimensions {-2, -1} (presumably the trailing
 * two axes, i.e. a batch of matrices -- confirm against NDArray docs), and
 * each input sub-array is processed into the matching output sub-array.
 */
template <typename T>
static void sqrtm_(const NDArray* x, NDArray* z) {
    // Single matrix: no batching required.
    if (x->rankOf() == 2) {
        ops::helpers::Sqrtm<T>::calc(*x, *z);
        return;
    }

    // Pair up the sub-arrays of the input and the output by index.
    auto inputSlices  = x->allTensorsAlongDimension({-2, -1});
    auto outputSlices = z->allTensorsAlongDimension({-2, -1});

    // Each invocation handles the index range [start, stop) it is given.
    auto processRange = PRAGMA_THREADS_FOR {
        for (auto idx = start; idx < stop; ++idx) {
            ops::helpers::Sqrtm<T>::calc(*inputSlices.at(idx), *outputSlices.at(idx));
        }
    };

    samediff::Threads::parallel_tad(processRange, 0, inputSlices.size());
}
//////////////////////////////////////////////////////////////////////////
/**
 * Entry point of the sqrtm helper: dispatches to the typed worker sqrtm_
 * according to the data type of the output array.
 *
 * @param context launch context; not referenced inside this function
 * @param x       input array, synchronized to host before the computation
 * @param z       output array, synchronized to device after the computation
 *
 * Only the types listed in FLOAT_TYPES are dispatched.
 */
void sqrtm(sd::LaunchContext* context, const NDArray* x, NDArray* z) {
    // Sync the input to the host side before the worker reads it.
    x->syncToHost();

    // The output's data type decides which sqrtm_<T> instantiation runs.
    const auto outputType = z->dataType();
    BUILD_SINGLE_SELECTOR(outputType, sqrtm_, (x, z), FLOAT_TYPES);

    // Sync the output to the device side once the worker has finished.
    z->syncToDevice();
}
}
}
}