/* ******************************************************************************
 *
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 *  See the NOTICE file distributed with this work for additional
 *  information regarding copyright ownership.
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

//
// Created by raver119 on 01/11/17.
//
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | #include <system/op_boilerplate.h>
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | #if NOT_EXCLUDED(OP_onehot)
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #include <ops/declarable/CustomOperations.h>
 | 
					
						
							|  |  |  | #include <helpers/ShapeUtils.h>
 | 
					
						
							|  |  |  | #include <ops/declarable/helpers/one_hot.h>
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  | namespace sd { | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |     namespace ops { | 
					
						
							|  |  |  |         CUSTOM_OP_IMPL(onehot, 1, 1, false, -2, -2) { | 
					
						
							|  |  |  |             auto input = INPUT_VARIABLE(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             // FIXME: double?
 | 
					
						
							|  |  |  |             double on(1.0f); // T_ARG(0);
 | 
					
						
							|  |  |  |             double off(0.0f); //T_ARG(1);
 | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2021-02-01 14:31:20 +09:00
										 |  |  |             auto axis = -1; //INT_ARG(0);
 | 
					
						
							|  |  |  |             auto depth = -1; //INT_ARG(1);
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             if (block.numI() > 0) | 
					
						
							|  |  |  |                 axis = INT_ARG(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if (block.numI() > 1) { | 
					
						
							|  |  |  |                 depth = INT_ARG(1); | 
					
						
							|  |  |  |             } else if (block.width() > 1) { | 
					
						
							|  |  |  |                 depth = INPUT_VARIABLE(1)->e<int>(0); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             REQUIRE_TRUE(depth > 0, 0, "OneHot: depth must be positive value"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if (block.width() > 2) { | 
					
						
							|  |  |  |                 on = INPUT_VARIABLE(2)->e<double>(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if (block.width() > 3) | 
					
						
							|  |  |  |                     off = INPUT_VARIABLE(3)->e<double>(0); | 
					
						
							|  |  |  |             } else if (block.numT() > 0) { | 
					
						
							|  |  |  |                 on = T_ARG(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 if (block.numT() > 1) | 
					
						
							|  |  |  |                     off = T_ARG(1); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             auto output = OUTPUT_VARIABLE(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if (axis < 0) | 
					
						
							|  |  |  |                 axis = output->rankOf() + axis; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             helpers::onehot(block.launchContext(), input, output, axis, depth, on, off); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             return Status::OK(); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         DECLARE_SHAPE_FN(onehot) { | 
					
						
							|  |  |  |             auto inShape = inputShape->at(0); | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |             sd::DataType dtype = block.numD() > 0 ? D_ARG(0) : sd::DataType::FLOAT32; | 
					
						
							| 
									
										
										
										
											2020-01-30 10:07:24 +03:00
										 |  |  | 
 | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |             int depth = -1; | 
					
						
							|  |  |  |             Nd4jLong axis = -1; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if (block.numI() > 0) | 
					
						
							|  |  |  |                 axis = INT_ARG(0); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |              if (block.numI() > 1) { | 
					
						
							|  |  |  |                 depth = INT_ARG(1); | 
					
						
							|  |  |  |             } else if (block.width() > 1) { | 
					
						
							|  |  |  |                 depth = INPUT_VARIABLE(1)->e<int>(0); | 
					
						
							|  |  |  |             } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             REQUIRE_TRUE(depth > 0, 0, "OneHot: depth must be positive value"); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             int rank = shape::rank(inShape); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if (axis < 0) | 
					
						
							|  |  |  |                 axis = rank + 1 + axis; | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             std::vector<Nd4jLong> shape; | 
					
						
							|  |  |  |             for (int e = 0; e < rank; e++) | 
					
						
							|  |  |  |                 shape.push_back(shape::shapeOf(inShape)[e]); | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             shape.insert(shape.begin() + axis, depth); | 
					
						
							| 
									
										
										
										
											2020-06-06 15:26:55 +03:00
										 |  |  |             auto newShape = ConstantShapeHelper::getInstance().createShapeInfo(dtype, 'c', rank + 1, shape.data()); | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             return SHAPELIST(newShape); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         DECLARE_TYPES(onehot) { | 
					
						
							|  |  |  |             getOpDescriptor() | 
					
						
							| 
									
										
										
										
											2020-03-02 12:49:41 +03:00
										 |  |  |                     ->setAllowedInputTypes(sd::DataType::ANY) | 
					
						
							| 
									
										
										
										
											2019-06-06 15:21:15 +03:00
										 |  |  |                     ->setAllowedOutputTypes({ALL_FLOATS, ALL_INTS}); | 
					
						
							|  |  |  |         } | 
					
						
							|  |  |  |     } | 
					
						
							|  |  |  | } | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | #endif
 |