print_input_processing.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_BINDINGS_PYTHON_PRINT_INPUT_PROCESSING_HPP
14 #define MLPACK_BINDINGS_PYTHON_PRINT_INPUT_PROCESSING_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 #include "get_arma_type.hpp"
18 #include "get_numpy_type.hpp"
19 #include "get_numpy_type_char.hpp"
20 #include "get_cython_type.hpp"
21 #include "strip_type.hpp"
22 #include "wrapper_functions.hpp"
23 
24 namespace mlpack {
25 namespace bindings {
26 namespace python {
27 
31 template<typename T>
33  util::ParamData& d,
34  const size_t indent,
35  const typename std::enable_if<!util::IsStdVector<T>::value>::type* = 0,
36  const typename std::enable_if<!arma::is_arma_type<T>::value>::type* = 0,
37  const typename std::enable_if<!data::HasSerialize<T>::value>::type* = 0,
38  const typename std::enable_if<!std::is_same<T,
39  std::tuple<data::DatasetInfo, arma::mat>>::value>::type* = 0)
40 {
41  // The copy_all_inputs parameter must be handled first, and therefore is
42  // outside the scope of this code.
43  if (d.name == "copy_all_inputs")
44  return;
45 
46  const std::string prefix(indent, ' ');
47 
48  std::string def = "None";
49  if (std::is_same<T, bool>::value)
50  def = "False";
51 
52  // Make sure that we don't use names that are Python keywords.
53  std::string name = GetValidName(d.name);
54 
66  std::cout << prefix << "# Detect if the parameter was passed; set if so."
67  << std::endl;
68  if (!d.required)
69  {
70  if (GetPrintableType<T>(d) == "bool")
71  {
72  std::cout << prefix << "if isinstance(" << name << ", "
73  << GetPrintableType<T>(d) << "):" << std::endl;
74  std::cout << prefix << " if " << name << " is not " << def << ":"
75  << std::endl;
76  }
77  else
78  {
79  std::cout << prefix << "if " << name << " is not " << def << ":"
80  << std::endl;
81  std::cout << prefix << " if isinstance(" << name << ", "
82  << GetPrintableType<T>(d) << "):" << std::endl;
83  }
84 
85  std::cout << prefix << " SetParam[" << GetCythonType<T>(d)
86  << "](p, <const string> '" << d.name << "', ";
87  if (GetCythonType<T>(d) == "string")
88  std::cout << name << ".encode(\"UTF-8\")";
89  else
90  std::cout << name;
91  std::cout << ")" << std::endl;
92  std::cout << prefix << " p.SetPassed(<const string> '" << d.name
93  << "')" << std::endl;
94 
95  // If this parameter is "verbose", then enable verbose output.
96  if (d.name == "verbose")
97  std::cout << prefix << " EnableVerbose()" << std::endl;
98 
99  if (GetPrintableType<T>(d) == "bool")
100  {
101  std::cout << " else:" << std::endl;
102  std::cout << " raise TypeError(" <<"\"'"<< name
103  << "' must have type \'" << GetPrintableType<T>(d)
104  << "'!\")" << std::endl;
105  }
106  else
107  {
108  std::cout << " else:" << std::endl;
109  std::cout << " raise TypeError(" <<"\"'"<< name
110  << "' must have type \'" << GetPrintableType<T>(d)
111  << "'!\")" << std::endl;
112  }
113  }
114  else
115  {
116  if (GetPrintableType<T>(d) == "bool")
117  {
118  std::cout << prefix << "if isinstance(" << name << ", "
119  << GetPrintableType<T>(d) << "):" << std::endl;
120  std::cout << prefix << " if " << name << " is not " << def << ":"
121  << std::endl;
122  }
123  else
124  {
125  std::cout << prefix << "if " << name << " is not " << def << ":"
126  << std::endl;
127  std::cout << prefix << " if isinstance(" << name << ", "
128  << GetPrintableType<T>(d) << "):" << std::endl;
129  }
130 
131  std::cout << prefix << " SetParam[" << GetCythonType<T>(d) << "](p, <const "
132  << "string> '" << d.name << "', ";
133  if (GetCythonType<T>(d) == "string")
134  std::cout << name << ".encode(\"UTF-8\")";
135  else if (GetCythonType<T>(d) == "vector[string]")
136  std::cout << "[i.encode(\"UTF-8\") for i in " << name << "]";
137  else
138  std::cout << name;
139  std::cout << ")" << std::endl;
140  std::cout << prefix << " p.SetPassed(<const string> '"
141  << d.name << "')" << std::endl;
142 
143  if (GetPrintableType<T>(d) == "bool")
144  {
145  std::cout << " else:" << std::endl;
146  std::cout << " raise TypeError(" <<"\"'"<< name
147  << "' must have type \'" << GetPrintableType<T>(d)
148  << "'!\")" << std::endl;
149  }
150  else
151  {
152  std::cout << " else:" << std::endl;
153  std::cout << " raise TypeError(" <<"\"'"<< name
154  << "' must have type \'" << GetPrintableType<T>(d)
155  << "'!\")" << std::endl;
156  }
157  }
158  std::cout << std::endl; // Extra line is to clear up the code a bit.
159 }
160 
164 template<typename T>
166  util::ParamData& d,
167  const size_t indent,
168  const typename std::enable_if<!arma::is_arma_type<T>::value>::type* = 0,
169  const typename std::enable_if<!data::HasSerialize<T>::value>::type* = 0,
170  const typename std::enable_if<!std::is_same<T,
171  std::tuple<data::DatasetInfo, arma::mat>>::value>::type* = 0,
172  const typename std::enable_if<util::IsStdVector<T>::value>::type* = 0)
173 {
174  const std::string prefix(indent, ' ');
175 
190  std::cout << prefix << "# Detect if the parameter was passed; set if so."
191  << std::endl;
192 
193  std::string name = GetValidName(d.name);
194 
195  if (!d.required)
196  {
197  std::cout << prefix << "if " << name << " is not None:"
198  << std::endl;
199  std::cout << prefix << " if isinstance(" << name << ", list):"
200  << std::endl;
201  std::cout << prefix << " if len(" << name << ") > 0:"
202  << std::endl;
203  std::cout << prefix << " if isinstance(" << name << "[0], "
204  << GetPrintableType<typename T::value_type>(d) << "):" << std::endl;
205  std::cout << prefix << " SetParam[" << GetCythonType<T>(d)
206  << "](p, <const string> '" << d.name << "', ";
207  // Strings need special handling.
208  if (GetCythonType<T>(d) == "vector[string]")
209  std::cout << "[i.encode(\"UTF-8\") for i in " << name << "]";
210  else
211  std::cout << name;
212  std::cout << ")" << std::endl;
213  std::cout << prefix << " p.SetPassed(<const string> '" << d.name
214  << "')" << std::endl;
215  std::cout << prefix << " else:" << std::endl;
216  std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
217  << "' must have type \'" << GetPrintableType<T>(d)
218  << "'!\")" << std::endl;
219  std::cout << prefix << " else:" << std::endl;
220  std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
221  << "' must have type \'list'!\")" << std::endl;
222  }
223  else
224  {
225  std::cout << prefix << "if isinstance(" << name << ", list):"
226  << std::endl;
227  std::cout << prefix << " if len(" << name << ") > 0:"
228  << std::endl;
229  std::cout << prefix << " if isinstance(" << name << "[0], "
230  << GetPrintableType<typename T::value_type>(d) << "):" << std::endl;
231  std::cout << prefix << " SetParam[" << GetCythonType<T>(d)
232  << "](p, <const string> '" << d.name << "', ";
233  // Strings need special handling.
234  if (GetCythonType<T>(d) == "vector[string]")
235  std::cout << "[i.encode(\"UTF-8\") for i in " << name << "]";
236  else
237  std::cout << name;
238  std::cout << ")" << std::endl;
239  std::cout << prefix << " p.SetPassed(<const string> '" << d.name
240  << "')" << std::endl;
241  std::cout << prefix << " else:" << std::endl;
242  std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
243  << "' must have type \'" << GetPrintableType<T>(d)
244  << "'!\")" << std::endl;
245  std::cout << prefix << "else:" << std::endl;
246  std::cout << prefix << " raise TypeError(" <<"\"'"<< d.name
247  << "' must have type \'list'!\")" << std::endl;
248  }
249 }
250 
254 template<typename T>
256  util::ParamData& d,
257  const size_t indent,
258  const typename std::enable_if<!util::IsStdVector<T>::value>::type* = 0,
259  const typename std::enable_if<arma::is_arma_type<T>::value>::type* = 0)
260 {
261  const std::string prefix(indent, ' ');
262 
278  std::cout << prefix << "# Detect if the parameter was passed; set if so."
279  << std::endl;
280  std::string name = GetValidName(d.name);
281 
282  if (!d.required)
283  {
284  if (T::is_row || T::is_col)
285  {
286  std::cout << prefix << "if " << name << " is not None:" << std::endl;
287  std::cout << prefix << " " << name << "_tuple = to_matrix("
288  << name << ", dtype=" << GetNumpyType<typename T::elem_type>()
289  << ", copy=p.Has('copy_all_inputs'))" << std::endl;
290  std::cout << prefix << " if len(" << name << "_tuple[0].shape) > 1:"
291  << std::endl;
292  std::cout << prefix << " if " << name << "_tuple[0]"
293  << ".shape[0] == 1 or " << name << "_tuple[0].shape[1] == 1:"
294  << std::endl;
295  std::cout << prefix << " " << name << "_tuple[0].shape = ("
296  << d.name << "_tuple[0].size,)" << std::endl;
297  std::cout << prefix << " " << name << "_mat = arma_numpy.numpy_to_"
298  << GetArmaType<T>() << "_" << GetNumpyTypeChar<T>() << "(" << name
299  << "_tuple[0], " << name << "_tuple[1])" << std::endl;
300  std::cout << prefix << " SetParam[" << GetCythonType<T>(d)
301  << "](p, <const string> '" << d.name << "', dereference("
302  << name << "_mat))"<< std::endl;
303  std::cout << prefix << " p.SetPassed(<const string> '" << d.name
304  << "')" << std::endl;
305  std::cout << prefix << " del " << name << "_mat" << std::endl;
306  }
307  else
308  {
309  std::cout << prefix << "if " << name << " is not None:" << std::endl;
310  std::cout << prefix << " " << name << "_tuple = to_matrix("
311  << name << ", dtype=" << GetNumpyType<typename T::elem_type>()
312  << ", copy=p.Has('copy_all_inputs'))" << std::endl;
313  std::cout << prefix << " if len(" << name << "_tuple[0].shape"
314  << ") < 2:" << std::endl;
315  std::cout << prefix << " " << name << "_tuple[0].shape = (" << name
316  << "_tuple[0].shape[0], 1)" << std::endl;
317  std::cout << prefix << " " << name << "_mat = arma_numpy.numpy_to_"
318  << GetArmaType<T>() << "_" << GetNumpyTypeChar<T>() << "(" << name
319  << "_tuple[0], " << name << "_tuple[1])" << std::endl;
320  std::cout << prefix << " SetParam[" << GetCythonType<T>(d)
321  << "](p, <const string> '" << d.name << "', dereference("
322  << name << "_mat))"<< std::endl;
323  std::cout << prefix << " p.SetPassed(<const string> '" << d.name
324  << "')" << std::endl;
325  std::cout << prefix << " del " << name << "_mat" << std::endl;
326  }
327  }
328  else
329  {
330  if (T::is_row || T::is_col)
331  {
332  std::cout << prefix << name << "_tuple = to_matrix(" << name
333  << ", dtype=" << GetNumpyType<typename T::elem_type>()
334  << ", copy=p.Has('copy_all_inputs'))" << std::endl;
335  std::cout << prefix << "if len(" << name << "_tuple[0].shape) > 1:"
336  << std::endl;
337  std::cout << prefix << " if " << name << "_tuple[0].shape[0] == 1 or "
338  << name << "_tuple[0].shape[1] == 1:" << std::endl;
339  std::cout << prefix << " " << name << "_tuple[0].shape = ("
340  << name << "_tuple[0].size,)" << std::endl;
341  std::cout << prefix << name << "_mat = arma_numpy.numpy_to_"
342  << GetArmaType<T>() << "_" << GetNumpyTypeChar<T>() << "(" << name
343  << "_tuple[0], " << name << "_tuple[1])" << std::endl;
344  std::cout << prefix << "SetParam[" << GetCythonType<T>(d)
345  << "](p, <const string> '" << d.name << "', dereference("
346  << name << "_mat))"<< std::endl;
347  std::cout << prefix << "p.SetPassed(<const string> '" << d.name << "')"
348  << std::endl;
349  std::cout << prefix << "del " << name << "_mat" << std::endl;
350  }
351  else
352  {
353  std::cout << prefix << name << "_tuple = to_matrix(" << name
354  << ", dtype=" << GetNumpyType<typename T::elem_type>()
355  << ", copy=p.Has('copy_all_inputs'))" << std::endl;
356  std::cout << prefix << "if len(" << name << "_tuple[0].shape) < 2:"
357  << std::endl;
358  std::cout << prefix << " " << name << "_tuple[0].shape = (" << name
359  << "_tuple[0].shape[0], 1)" << std::endl;
360  std::cout << prefix << name << "_mat = arma_numpy.numpy_to_"
361  << GetArmaType<T>() << "_" << GetNumpyTypeChar<T>() << "(" << name
362  << "_tuple[0], " << name << "_tuple[1])" << std::endl;
363  std::cout << prefix << "SetParam[" << GetCythonType<T>(d)
364  << "](p, <const string> '" << d.name << "', dereference(" << name
365  << "_mat))" << std::endl;
366  std::cout << prefix << "p.SetPassed(<const string> '" << d.name << "')"
367  << std::endl;
368  std::cout << prefix << "del " << name << "_mat" << std::endl;
369  }
370  }
371  std::cout << std::endl;
372 }
373 
377 template<typename T>
379  util::ParamData& d,
380  const size_t indent,
381  const typename std::enable_if<!util::IsStdVector<T>::value>::type* = 0,
382  const typename std::enable_if<!arma::is_arma_type<T>::value>::type* = 0,
383  const typename std::enable_if<data::HasSerialize<T>::value>::type* = 0)
384 {
385  // First, get the correct class name if needed.
386  std::string strippedType, printedType, defaultsType;
387  StripType(d.cppType, strippedType, printedType, defaultsType);
388 
389  std::string name = GetValidName(d.name);
390 
391  const std::string prefix(indent, ' ');
392 
409  std::cout << prefix << "# Detect if the parameter was passed; set if so."
410  << std::endl;
411  if (!d.required)
412  {
413  std::cout << prefix << "if " << name << " is not None:" << std::endl;
414  std::cout << prefix << " try:" << std::endl;
415  std::cout << prefix << " SetParamPtr[" << strippedType << "](p, '" << d.name
416  << "', (<" << strippedType << "Type?> " << name << ").modelptr, "
417  << "p.Has('copy_all_inputs'))" << std::endl;
418  std::cout << prefix << " except TypeError as e:" << std::endl;
419  std::cout << prefix << " if type(" << name << ").__name__ == '"
420  << strippedType << "Type':" << std::endl;
421  std::cout << prefix << " SetParamPtr[" << strippedType << "](p, '"
422  << d.name << "', (<" << strippedType << "Type> " << name
423  << ").modelptr, p.Has('copy_all_inputs'))" << std::endl;
424  std::cout << prefix << " else:" << std::endl;
425  std::cout << prefix << " raise e" << std::endl;
426  std::cout << prefix << " p.SetPassed(<const string> '" << d.name << "')"
427  << std::endl;
428  }
429  else
430  {
431  std::cout << prefix << "try:" << std::endl;
432  std::cout << prefix << " SetParamPtr[" << strippedType << "](p, '" << d.name
433  << "', (<" << strippedType << "Type?> " << name << ").modelptr, "
434  << "p.Has('copy_all_inputs'))" << std::endl;
435  std::cout << prefix << "except TypeError as e:" << std::endl;
436  std::cout << prefix << " if type(" << name << ").__name__ == '"
437  << strippedType << "Type':" << std::endl;
438  std::cout << prefix << " SetParamPtr[" << strippedType << "](p,'" << d.name
439  << "', (<" << strippedType << "Type> " << name << ").modelptr, "
440  << "p.Has('copy_all_inputs'))" << std::endl;
441  std::cout << prefix << " else:" << std::endl;
442  std::cout << prefix << " raise e" << std::endl;
443  std::cout << prefix << "p.SetPassed(<const string> '" << d.name << "')"
444  << std::endl;
445  }
446  std::cout << std::endl;
447 }
448 
452 template<typename T>
454  util::ParamData& d,
455  const size_t indent,
456  const typename std::enable_if<!util::IsStdVector<T>::value>::type* = 0,
457  const typename std::enable_if<std::is_same<T,
458  std::tuple<data::DatasetInfo, arma::mat>>::value>::type* = 0)
459 {
460  std::string name = GetValidName(d.name);
461 
462  // The user should pass in a matrix type of some sort.
463  const std::string prefix(indent, ' ');
464 
479  std::cout << prefix << "cdef np.ndarray " << name << "_dims" << std::endl;
480  std::cout << prefix << "# Detect if the parameter was passed; set if so."
481  << std::endl;
482  if (!d.required)
483  {
484  std::cout << prefix << "cdef extern from \"numpy/arrayobject.h\":" << std::endl;
485  std::cout << prefix << " void* PyArray_DATA(np.ndarray arr)" << std::endl;
486  std::cout << prefix << "if " << d.name << " is not None:" << std::endl;
487  std::cout << prefix << " " << d.name << "_tuple = to_matrix_with_info("
488  << d.name << ", dtype=np.double, copy=p.Has('copy_all_inputs'))"
489  << std::endl;
490  std::cout << prefix << " if len(" << name << "_tuple[0].shape"
491  << ") < 2:" << std::endl;
492  std::cout << prefix << " " << name << "_tuple[0].shape = (" << name
493  << "_tuple[0].shape[0], 1)" << std::endl;
494  std::cout << prefix << " " << name << "_mat = arma_numpy.numpy_to_mat_d("
495  << name << "_tuple[0], " << name << "_tuple[1])" << std::endl;
496  std::cout << prefix << " " << name << "_dims = " << name
497  << "_tuple[2]" << std::endl;
498  std::cout << prefix << " SetParamWithInfo[arma.Mat[double]](p, <const "
499  << "string> '" << d.name << "', dereference(" << d.name << "_mat), "
500  << "<const cbool*> PyArray_DATA(" << d.name << "_dims))" << std::endl;
501  std::cout << prefix << " p.SetPassed(<const string> '" << d.name
502  << "')" << std::endl;
503  std::cout << prefix << " del " << name << "_mat" << std::endl;
504  }
505  else
506  {
507  std::cout << prefix << "cdef extern from \"numpy/arrayobject.h\":" << std::endl;
508  std::cout << prefix << " void* PyArray_DATA(np.ndarray arr)" << std::endl;
509  std::cout << prefix << d.name << "_tuple = to_matrix_with_info(" << d.name
510  << ", dtype=np.double, copy=p.Has('copy_all_inputs'))"
511  << std::endl;
512  std::cout << prefix << "if len(" << name << "_tuple[0].shape"
513  << ") < 2:" << std::endl;
514  std::cout << prefix << " " << name << "_tuple[0].shape = (" << name
515  << "_tuple[0].shape[0], 1)" << std::endl;
516  std::cout << prefix << name << "_mat = arma_numpy.numpy_to_mat_d("
517  << name << "_tuple[0], " << name << "_tuple[1])" << std::endl;
518  std::cout << prefix << name << "_dims = " << name << "_tuple[2]"
519  << std::endl;
520  std::cout << prefix << "SetParamWithInfo[arma.Mat[double]](p, <const "
521  << "string> '" << d.name << "', dereference(" << d.name << "_mat), "
522  << "<const cbool*> PyArray_DATA(" << d.name << "_dims))" << std::endl;
523  std::cout << prefix << "p.SetPassed(<const string> '" << d.name << "')"
524  << std::endl;
525  std::cout << prefix << "del " << name << "_mat" << std::endl;
526  }
527  std::cout << std::endl;
528 }
529 
541 template<typename T>
543  const void* input,
544  void* /* output */)
545 {
546  PrintInputProcessing<typename std::remove_pointer<T>::type>(d,
547  *((size_t*) input));
548 }
549 
550 } // namespace python
551 } // namespace bindings
552 } // namespace mlpack
553 
554 #endif
Linear algebra utility functions, generally performed on matrices or vectors.
The core includes that mlpack expects; standard C++ includes, Armadillo, cereal, and a few basic mlpa...
python
Definition: CMakeLists.txt:7
This structure holds all of the information about a single parameter, including its value (which is s...
Definition: param_data.hpp:38
Metaprogramming structure for vector detection.
void PrintInputProcessing(util::ParamData &d, const size_t indent, const typename std::enable_if<!util::IsStdVector< T >::value >::type *=0, const typename std::enable_if<!arma::is_arma_type< T >::value >::type *=0, const typename std::enable_if<!data::HasSerialize< T >::value >::type *=0, const typename std::enable_if<!std::is_same< T, std::tuple< data::DatasetInfo, arma::mat >>::value >::type *=0)
Print input processing for a standard option type.
std::string name
Name of this parameter.
Definition: param_data.hpp:42
bool required
True if this option is required.
Definition: param_data.hpp:57
void StripType(const std::string &inputType, std::string &strippedType, std::string &printedType, std::string &defaultsType)
Given an input type like, e.g., "LogisticRegression<>", return three types that can be used in Python...
Definition: strip_type.hpp:28
std::string cppType
The true name of the type, as it would be written in C++.
Definition: param_data.hpp:67
if(NOT BUILD_GO_SHLIB) macro(add_go_binding name) endmacro() return() endif() endmacro() macro(post_go_setup) if(BUILD_GO_BINDINGS) file(APPEND "$
Definition: CMakeLists.txt:3
std::string GetValidName(const std::string &paramName)