/******************************************************************************* * Copyright (c) 2015-2019 Skymind, Inc. * * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the * License for the specific language governing permissions and limitations * under the License. * * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ // // Created by saudet on 8/30/2018. // #ifndef LIBND4J_MKLDNNSTREAM_H #define LIBND4J_MKLDNNSTREAM_H #ifndef __STANDALONE_BUILD__ #include "config.h" #endif #ifdef HAVE_MKLDNN #include namespace nd4j { class MKLDNNStream { protected: std::string _opName; std::vector _inputs; std::vector _outputs; std::vector _floatArguments; std::vector _intArguments; mkldnn::engine _engine = mkldnn::engine(mkldnn::engine::cpu, 0); std::vector _memory; std::vector _operations; public: template static bool isSupported() { return typeid(X) == typeid(float) && typeid(Y) == typeid(float); } static bool isSupported(const std::vector &arrays) { for (auto i = arrays.begin(); i != arrays.end(); i++) { if (*i != nullptr && (*i)->dataType() != nd4j::DataType::FLOAT32) { return false; } } return true; } MKLDNNStream(const std::string &opName) : _opName(opName) { } bool checkAndReset(const std::vector &inputs, const std::vector &outputs, const std::vector &floatArguments, const std::vector &intArguments) { if (inputs != _inputs || outputs != _outputs || floatArguments != _floatArguments || intArguments != _intArguments) { _inputs = inputs; _outputs = outputs; _floatArguments = floatArguments; _intArguments = intArguments; _operations.clear(); _memory.clear(); return true; } return false; } const mkldnn::engine &getEngine() { return _engine; } void setEngine(const mkldnn::engine &engine) { _engine = engine; } const std::vector &getMemory() { return _memory; } void setMemory(const std::vector &memory) { _memory = memory; } void addMemory(const mkldnn::memory &memory) { _memory.push_back(memory); } const std::vector &getOperations() { return _operations; } void setOperations(const std::vector &operations) { _operations = operations; } void addOperation(const mkldnn::primitive &operation) { _operations.push_back(operation); } bool submitAndWait(mkldnn::stream::kind kind = mkldnn::stream::kind::eager) { nd4j_debug("Executing %s with MKL-DNN\n", _opName.c_str()); // need to create a new one because already executed streams become unusable mkldnn::stream stream(kind); return stream.submit(_operations).wait(); } }; } #endif #endif //LIBND4J_MKLDNNSTREAM_H