/******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by saudet on 8/30/2018. // #ifndef LIBND4J_MKLDNNSTREAM_H #define LIBND4J_MKLDNNSTREAM_H #if !defined(__STANDALONE_BUILD__) #include "config.h" #endif #if defined(HAVE_MKLDNN) namespace sd { class MKLDNNStream { protected: std::string _opName; std::vector _inputs; std::vector _outputs; std::vector _floatArguments; std::vector _intArguments; public: template static bool isSupported() { // FIXME: strict float support doesn't work anymore return typeid(X) == typeid(float) && typeid(Y) == typeid(float); } static bool isSupported(const std::vector &arrays) { // FIXME: strict float support doesn't work anymore for (auto v:arrays) { if (v != nullptr && v->dataType() != sd::DataType::FLOAT32) { return false; } } return true; } explicit MKLDNNStream(const std::string &opName) : _opName(opName) { } bool checkAndReset(const std::vector &inputs, const std::vector &outputs, const std::vector &floatArguments, const std::vector &intArguments) { if (inputs != _inputs || outputs != _outputs || floatArguments != _floatArguments || intArguments != _intArguments) { _inputs = inputs; _outputs = outputs; _floatArguments = floatArguments; _intArguments = intArguments; return true; } return false; } }; } #endif #endif //LIBND4J_MKLDNNSTREAM_H