| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  | /*******************************************************************************
 | 
					
						
							|  |  |  |  * Copyright (c) 2015-2018 Skymind, Inc. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * This program and the accompanying materials are made available under the | 
					
						
							|  |  |  |  * terms of the Apache License, Version 2.0 which is available at | 
					
						
							|  |  |  |  * https://www.apache.org/licenses/LICENSE-2.0.
 | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * Unless required by applicable law or agreed to in writing, software | 
					
						
							|  |  |  |  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | 
					
						
							|  |  |  |  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | 
					
						
							|  |  |  |  * License for the specific language governing permissions and limitations | 
					
						
							|  |  |  |  * under the License. | 
					
						
							|  |  |  |  * | 
					
						
							|  |  |  |  * SPDX-License-Identifier: Apache-2.0 | 
					
						
							|  |  |  |  ******************************************************************************/ | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | // @author saudet
 | 
					
						
							|  |  |  | // @author raver119@gmail.com
 | 
					
						
							|  |  |  | //
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <ops/declarable/PlatformHelper.h>
 | 
					
						
							|  |  |  | #include <ops/declarable/OpRegistrator.h>
 | 
					
						
							|  |  |  | #include <platform_boilerplate.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <helpers/MKLDNNStream.h>
 | 
					
						
							|  |  |  | #include "mkldnnUtils.h"
 | 
					
						
							|  |  |  | #include <ops/declarable/helpers/convolutions.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 13:23:08 +03:00
										 |  |  | using namespace dnnl; | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  | namespace nd4j { | 
					
						
							|  |  |  |     namespace ops { | 
					
						
							|  |  |  |         namespace platforms { | 
					
						
							|  |  |  |             PLATFORM_IMPL(lrn) { | 
					
						
							|  |  |  |                 auto input = INPUT_VARIABLE(0); | 
					
						
							|  |  |  |                 auto output = OUTPUT_VARIABLE(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 REQUIRE_TRUE(input->rankOf() == 4, 0, "lrn: Input rank of 4 expected, but got %i instead", | 
					
						
							|  |  |  |                              input->rankOf()); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 double alpha = T_ARG(1); | 
					
						
							|  |  |  |                 double beta = T_ARG(2); | 
					
						
							|  |  |  |                 double bias = T_ARG(0); | 
					
						
							|  |  |  |                 int depth = INT_ARG(0); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 13:23:08 +03:00
										 |  |  |                 dnnl_memory_desc_t empty; | 
					
						
							|  |  |  |                 dnnl::memory::desc lrn_src_md(empty), lrn_dst_md(empty), user_src_md(empty), user_dst_md(empty); | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 mkldnnUtils::getMKLDNNMemoryDescLrn(input, nullptr, output, &lrn_src_md, nullptr, &lrn_dst_md, | 
					
						
							|  |  |  |                                                         &user_src_md, nullptr, &user_dst_md, input->rankOf() - 1); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 auto lrn_desc = lrn_forward::desc(prop_kind::forward_inference, algorithm::lrn_across_channels, | 
					
						
							|  |  |  |                                                       lrn_src_md, (2 * depth + 1), alpha * (2 * depth + 1), beta, bias); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 auto engine = mkldnnUtils::getEngine(LaunchContext::defaultContext()->engine()); | 
					
						
							| 
									
										
										
										
											2019-11-20 13:23:08 +03:00
										 |  |  |                 dnnl::stream stream(engine); | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |                 auto lrn_prim_desc = lrn_forward::primitive_desc(lrn_desc, engine); | 
					
						
							| 
									
										
										
										
											2019-11-20 13:23:08 +03:00
										 |  |  |                 auto user_src_memory = dnnl::memory(user_src_md, engine, input->buffer()); | 
					
						
							|  |  |  |                 auto user_dst_memory = dnnl::memory(user_dst_md, engine, output->buffer()); | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 auto lrn_src_memory = user_src_memory; | 
					
						
							|  |  |  |                 if (lrn_prim_desc.src_desc() != user_src_memory.get_desc()) { | 
					
						
							| 
									
										
										
										
											2019-11-20 13:23:08 +03:00
										 |  |  |                     lrn_src_memory = dnnl::memory(lrn_prim_desc.src_desc(), engine); | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |                     reorder(user_src_memory, lrn_src_memory).execute(stream, user_src_memory, lrn_src_memory); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 auto lrn_dst_memory = user_dst_memory; | 
					
						
							|  |  |  |                 if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { | 
					
						
							| 
									
										
										
										
											2019-11-20 13:23:08 +03:00
										 |  |  |                     lrn_dst_memory = dnnl::memory(lrn_prim_desc.dst_desc(), engine); | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-11-20 13:23:08 +03:00
										 |  |  |                 lrn_forward(lrn_prim_desc).execute(stream, {{DNNL_ARG_SRC, lrn_src_memory}, | 
					
						
							|  |  |  |                                                                 {DNNL_ARG_DST, lrn_dst_memory}}); | 
					
						
							| 
									
										
										
										
											2019-09-11 21:50:28 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                 if (lrn_prim_desc.dst_desc() != user_dst_memory.get_desc()) { | 
					
						
							|  |  |  |                     reorder(lrn_dst_memory, user_dst_memory).execute(stream, lrn_dst_memory, user_dst_memory); | 
					
						
							|  |  |  |                 } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 stream.wait(); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 return Status::OK(); | 
					
						
							|  |  |  |             }; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             PLATFORM_CHECK(lrn) { | 
					
						
							|  |  |  |                 auto input = INPUT_VARIABLE(0); | 
					
						
							|  |  |  |                 auto output = OUTPUT_VARIABLE(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 return block.isUseMKLDNN() && nd4j::MKLDNNStream::isSupported({input, output}); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } |